
Commit ca008eb

Clean up IR.py (#1084)
## Description

- Remove torch version checking and, as a result, the dependency on `packaging`
- Pack the multi-use parameter logic into functions
- Remove the duplicated `_split_before_forwad` and `_split_before_backwad`
1 parent a973821 commit ca008eb
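
For context on the multi-use parameter bullet: a "multi-use" parameter is one whose `get_attr` node ends up with more than one user, which is exactly the condition checked in the IR.py diff below and now reported via `NotImplementedError` instead of a bare assert. The following is a minimal, hypothetical sketch of the pattern being detected; the model, names, and shapes are illustrative and are not part of this commit.

```python
# Illustrative only: a model in which the same parameter is consumed twice.
# After pipeline splitting, the two uses could land in different stages.
import torch
import torch.fx as fx


class TiedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = torch.nn.Parameter(torch.randn(16, 16))

    def forward(self, x):
        y = x @ self.shared
        return y @ self.shared.t()


traced = fx.symbolic_trace(TiedModel())
# Mirror of the check in the diff: a get_attr node with more than one user.
multi_use = {
    node.target: len(node.users)
    for node in traced.graph.nodes
    if node.op == "get_attr" and len(node.users) > 1
}
print(multi_use)  # e.g. {'shared': 2} -> such a model is now rejected
```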

File tree: 3 files changed, +142 -140 lines


pippy/IR.py (+142 -138)
```diff
@@ -10,7 +10,6 @@

 import torch
 import torch.fx as fx
-from packaging import version
 from torch.export import ExportedProgram
 from torch.fx.node import map_aggregate
 from torch.fx.passes.split_module import split_module
@@ -29,22 +28,9 @@

 logger = logging.getLogger(__name__)

-# Because split_module with 4 arguments is available only in PT 1.12+
-TORCH_FX_REQUIRED_VERSION = version.parse("1.12")
-
-torch_version = version.parse(torch.__version__)
-assert (torch_version.major, torch_version.minor) >= (  # type: ignore
-    TORCH_FX_REQUIRED_VERSION.major,  # type: ignore
-    TORCH_FX_REQUIRED_VERSION.minor,  # type: ignore
-), "PiPPy requires PyTorch >= 1.12"
-
 # TODO:
 # 1. investigate gradient sync for shared parameters. how does DDP do it?
-# 2. Modify serialization to put parameters back in their original qualname form?
-# 3. Shape specialized tracing?
-# 4. Can we define semantics for shared module call? Can we make this configurable in the same way as
-# with shared parameters? probably need to modify split_module in this case
-# 5. Add parameter movement to split_module
+# 2. Add parameter movement to split_module


 def _find_loss_from_output_and_spec(output_val, spec_val):
@@ -931,134 +917,164 @@ def move_param_to_callee(
             if node.op == "get_attr" and len(node.users) > 1:
                 multi_use_params_qualnames.setdefault(node.target)

-        # TODO: re-enable support for multi-use parameters
-        assert len(multi_use_params_qualnames) == 0
-
-        """
-        for param in multi_use_params_qualnames:
-            if isinstance(multi_use_param_spec, MultiUseParameterConfig):
-                multi_use_params_qualnames[param] = multi_use_param_spec
-            elif isinstance(multi_use_param_spec, dict):
-                multi_use_params_qualnames[param] = multi_use_param_spec.get(
-                    param, MultiUseParameterConfig.TRANSMIT
-                )
-            else:
-                raise ValueError(
-                    "multi_use_param_spec must be MultiUseParamSpec enum or dict"
-                )
-
-        # TODO: do we maintain the invariant that `Node.users` is topologically ordered? I don't think so
-        node_to_first_user: Dict[fx.Node, fx.Node] = {}
-        for node in split.graph.nodes:
-            for input in node.all_input_nodes:
-                if input not in node_to_first_user:
-                    node_to_first_user[input] = node
-
-        for node in split.graph.nodes:
-            if (
-                node.op == "get_attr"
-                and node.target in multi_use_params_qualnames
-            ):
-                reuse_type = multi_use_params_qualnames[node.target]
-                if reuse_type == MultiUseParameterConfig.TRANSMIT:
-                    first_user = node_to_first_user[node]
-                    assert first_user.op == "call_module"
-
-                    use_idx = delete_user_reference(
-                        node, first_user, delete_node=False
+        def set_multi_use_param_spec(
+            multi_use_params_qualnames,
+            multi_use_param_spec,
+        ):
+            for param in multi_use_params_qualnames:
+                if isinstance(multi_use_param_spec, MultiUseParameterConfig):
+                    multi_use_params_qualnames[param] = multi_use_param_spec
+                elif isinstance(multi_use_param_spec, dict):
+                    multi_use_params_qualnames[
+                        param
+                    ] = multi_use_param_spec.get(
+                        param, MultiUseParameterConfig.TRANSMIT
                     )
-
-                    atoms = node.target.split(".")
-                    mod_itr = split
-                    for atom in atoms[:-1]:
-                        mod_itr = getattr(mod_itr, atom)
-                    param_val = getattr(mod_itr, atoms[-1])
-                    is_buffer = atoms[-1] in mod_itr._buffers
-
-                    callee_param_def = move_param_to_callee(
-                        split, first_user.target, param_val, use_idx, is_buffer
+                else:
+                    raise ValueError(
+                        "multi_use_param_spec must be MultiUseParamSpec enum or dict"
                    )

-                    delattr(mod_itr, atoms[-1])
-
-                    # Add extra output to the callee and switch references to the parameter
-                    # access in the pipeline graph to use this.
-                    submod = split.get_submodule(first_user.target)
-                    callee_output_nodes = [
-                        n for n in submod.graph.nodes if n.op == "output"
-                    ]
-                    assert len(callee_output_nodes) == 1
-                    callee_output_node = callee_output_nodes[0]
+        def handle_multi_use_params(
+            split,
+            multi_use_params_qualnames,
+        ):
+            # TODO: do we maintain the invariant that `Node.users` is topologically ordered? I don't think so
+            node_to_first_user: Dict[fx.Node, fx.Node] = {}
+            for node in split.graph.nodes:
+                for input in node.all_input_nodes:
+                    if input not in node_to_first_user:
+                        node_to_first_user[input] = node
+
+            for node in split.graph.nodes:
+                if (
+                    node.op == "get_attr"
+                    and node.target in multi_use_params_qualnames
+                ):
+                    reuse_type = multi_use_params_qualnames[node.target]
+                    if reuse_type == MultiUseParameterConfig.TRANSMIT:
+                        first_user = node_to_first_user[node]
+                        assert first_user.op == "call_module"

-                    # TODO: zero outputs?
-                    if isinstance(callee_output_node.args[0], tuple):
-                        new_output_args = callee_output_node.args[0] + (
-                            callee_param_def,
-                        )
-                        callee_output_node.args = (new_output_args,)
-                        new_output_idx = len(new_output_args) - 1
-                        promoted_to_tuple = False
-                    else:
-                        new_output_args = (
-                            callee_output_node.args[0],
-                            callee_param_def,
-                        )
-                        callee_output_node.args = (new_output_args,)
-                        new_output_idx = len(new_output_args) - 1
-                        promoted_to_tuple = True
-
-                    submod.graph.lint()
-                    submod.recompile()
-
-                    with split.graph.inserting_after(first_user):
-                        if promoted_to_tuple:
-                            # TODO: test this code path
-                            orig_output_getitem = split.graph.call_function(
-                                operator.getitem, (first_user, 0)
-                            )
-                            first_user.replace_all_uses_with(
-                                orig_output_getitem
-                            )
-                            # HACK because the above replace_all_uses with ALSO replaced the instance
-                            # of first_user within the getitem node we just added
-                            orig_output_getitem.args = (
-                                first_user,
-                            ) + orig_output_getitem.args[1:]
-
-                        transmitted_value_getitem = split.graph.call_function(
-                            operator.getitem, (first_user, new_output_idx)
-                        )
-                        node.replace_all_uses_with(transmitted_value_getitem)
-                        split.graph.erase_node(node)
-                elif reuse_type == MultiUseParameterConfig.REPLICATE:
-                    for user in copy.copy(node.users):
                         use_idx = delete_user_reference(
-                            node, user, delete_node=False
+                            node, first_user, delete_node=False
                         )
+
                         atoms = node.target.split(".")
                         mod_itr = split
                         for atom in atoms[:-1]:
                             mod_itr = getattr(mod_itr, atom)
                         param_val = getattr(mod_itr, atoms[-1])
                         is_buffer = atoms[-1] in mod_itr._buffers

-                        move_param_to_callee(
-                            split, user.target, param_val, use_idx, is_buffer
+                        callee_param_def = move_param_to_callee(  # type: ignore[call-arg]
+                            split,
+                            first_user.target,
+                            param_val,
+                            use_idx,
+                            is_buffer,
                         )

-                        atoms = node.target.split(".")
-                        mod_itr = split
-                        for atom in atoms[:-1]:
-                            mod_itr = getattr(mod_itr, atom)
+                        delattr(mod_itr, atoms[-1])
+
+                        # Add extra output to the callee and switch references to the parameter
+                        # access in the pipeline graph to use this.
+                        submod = split.get_submodule(first_user.target)
+                        callee_output_nodes = [
+                            n for n in submod.graph.nodes if n.op == "output"
+                        ]
+                        assert len(callee_output_nodes) == 1
+                        callee_output_node = callee_output_nodes[0]
+
+                        # TODO: zero outputs?
+                        if isinstance(callee_output_node.args[0], tuple):
+                            new_output_args = callee_output_node.args[0] + (
+                                callee_param_def,
+                            )
+                            callee_output_node.args = (new_output_args,)
+                            new_output_idx = len(new_output_args) - 1
+                            promoted_to_tuple = False
+                        else:
+                            new_output_args = (
+                                callee_output_node.args[0],
+                                callee_param_def,
+                            )
+                            callee_output_node.args = (new_output_args,)
+                            new_output_idx = len(new_output_args) - 1
+                            promoted_to_tuple = True
+
+                        submod.graph.lint()
+                        submod.recompile()
+
+                        with split.graph.inserting_after(first_user):
+                            if promoted_to_tuple:
+                                # TODO: test this code path
+                                orig_output_getitem = split.graph.call_function(
+                                    operator.getitem, (first_user, 0)
+                                )
+                                first_user.replace_all_uses_with(
+                                    orig_output_getitem
+                                )
+                                # HACK because the above replace_all_uses with ALSO replaced the instance
+                                # of first_user within the getitem node we just added
+                                orig_output_getitem.args = (
+                                    first_user,
+                                ) + orig_output_getitem.args[1:]
+
+                            transmitted_value_getitem = (
+                                split.graph.call_function(
+                                    operator.getitem,
+                                    (first_user, new_output_idx),
+                                )
+                            )
+                            node.replace_all_uses_with(
+                                transmitted_value_getitem
+                            )
+                            split.graph.erase_node(node)
+                    elif reuse_type == MultiUseParameterConfig.REPLICATE:
+                        for user in copy.copy(node.users):
+                            use_idx = delete_user_reference(
+                                node, user, delete_node=False
+                            )
+                            atoms = node.target.split(".")
+                            mod_itr = split
+                            for atom in atoms[:-1]:
+                                mod_itr = getattr(mod_itr, atom)
+                            param_val = getattr(mod_itr, atoms[-1])
+                            is_buffer = atoms[-1] in mod_itr._buffers
+
+                            move_param_to_callee(  # type: ignore[call-arg]
+                                split,
+                                user.target,
+                                param_val,
+                                use_idx,
+                                is_buffer,
+                            )

-                        delattr(mod_itr, atoms[-1])
+                            atoms = node.target.split(".")
+                            mod_itr = split
+                            for atom in atoms[:-1]:
+                                mod_itr = getattr(mod_itr, atom)

-                    split.graph.erase_node(node)
-                else:
-                    raise ValueError(
-                        f"Unknown multi-use config value {reuse_type} specified for {node.target}"
-                    )
-        """
+                            delattr(mod_itr, atoms[-1])
+
+                        split.graph.erase_node(node)
+                    else:
+                        raise ValueError(
+                            f"Unknown multi-use config value {reuse_type} specified for {node.target}"
+                        )
+
+        if len(multi_use_params_qualnames) > 0:
+            # TODO: re-enable support for multi-use parameters
+            raise NotImplementedError(
+                "Sharing model parameters between stages are not yet supported. "
+                "Found the following shared parameters in your model: "
+                f"{multi_use_params_qualnames}"
+            )
+        set_multi_use_param_spec(
+            multi_use_params_qualnames, multi_use_param_spec
+        )
+        handle_multi_use_params(split, multi_use_params_qualnames)

         split.delete_all_unused_submodules()
         split.graph.lint()
@@ -1285,18 +1301,6 @@ class SplitPoint(Enum):
     END = 2


-def _split_before_forwad(self, *args, **kwargs):
-    pipe_split()
-    return self.orig_forward(*args, **kwargs)
-
-
-def _split_after_forwad(self, *args, **kwargs):
-    try:
-        return self.orig_forward(*args, **kwargs)
-    finally:
-        pipe_split()
-
-
 # For backward compatibility, we kept the PipeSplitWrapper class because `class
 # SplitPoint` used to be defined in this class.
 class PipeSplitWrapper:
```
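
A note on the refactored helpers above: `set_multi_use_param_spec` accepts either a single `MultiUseParameterConfig` value or a dict keyed by parameter qualname, defaulting unlisted entries to `TRANSMIT`; with this commit the helpers only run when no shared parameters were found, because the new `NotImplementedError` fires first. A hedged sketch of the two spec forms follows; the enum below is a local stand-in for the one in `pippy/IR.py` (member values arbitrary), and the qualnames are made up.

```python
from enum import Enum


class MultiUseParameterConfig(Enum):
    # Stand-in for the enum in pippy/IR.py; only the names matter here.
    TRANSMIT = 1
    REPLICATE = 2


# Qualnames detected as multi-use (hypothetical examples), values start unset.
multi_use_params_qualnames = {"decoder.weight": None, "encoder.bias": None}

# Form 1: one config applied to every shared parameter.
spec_all = MultiUseParameterConfig.REPLICATE

# Form 2: a per-parameter dict; anything not listed falls back to TRANSMIT,
# mirroring the `.get(param, MultiUseParameterConfig.TRANSMIT)` default above.
spec_per_param = {"decoder.weight": MultiUseParameterConfig.REPLICATE}

# Normalization outcome: every detected qualname maps to a concrete config.
for param in multi_use_params_qualnames:
    multi_use_params_qualnames[param] = spec_per_param.get(
        param, MultiUseParameterConfig.TRANSMIT
    )
print(multi_use_params_qualnames)
```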

requirements.txt (-1)

```diff
@@ -1,2 +1 @@
 torch >= 2.3.0.dev
-packaging >= 21.3
```

setup.py (-1)

```diff
@@ -46,7 +46,6 @@ def write_version_file():
     # If the torch version has a ".dev" suffix, it would represent a nightly version of PyTorch.
     # It can be installed as a binary or from source.
     "torch>=2.3.0.dev",
-    "packaging>=21.3",
 ]

 extras: Dict = {}
```
