@@ -10,7 +10,6 @@
10 | 10 |
11 | 11 | import torch
12 | 12 | import torch.fx as fx
13 | | -from packaging import version
14 | 13 | from torch.export import ExportedProgram
15 | 14 | from torch.fx.node import map_aggregate
16 | 15 | from torch.fx.passes.split_module import split_module
@@ -29,22 +28,9 @@
29 | 28 |
30 | 29 | logger = logging.getLogger(__name__)
31 | 30 |
32 | | -# Because split_module with 4 arguments is available only in PT 1.12+
33 | | -TORCH_FX_REQUIRED_VERSION = version.parse("1.12")
34 | | -
35 | | -torch_version = version.parse(torch.__version__)
36 | | -assert (torch_version.major, torch_version.minor) >= (  # type: ignore
37 | | -    TORCH_FX_REQUIRED_VERSION.major,  # type: ignore
38 | | -    TORCH_FX_REQUIRED_VERSION.minor,  # type: ignore
39 | | -), "PiPPy requires PyTorch >= 1.12"
40 | | -
41 | 31 | # TODO:
42 | 32 | # 1. investigate gradient sync for shared parameters. how does DDP do it?
43 | | -# 2. Modify serialization to put parameters back in their original qualname form?
44 | | -# 3. Shape specialized tracing?
45 | | -# 4. Can we define semantics for shared module call? Can we make this configurable in the same way as
46 | | -#    with shared parameters? probably need to modify split_module in this case
47 | | -# 5. Add parameter movement to split_module
| 33 | +# 2. Add parameter movement to split_module
48 | 34 |
49 | 35 |
50 | 36 | def _find_loss_from_output_and_spec(output_val, spec_val):

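Context for the deletion above: the guard enforced PyTorch >= 1.12 at import time because `split_module` with four arguments only exists from that release on; with the minimum version now guaranteed elsewhere, both the check and the `from packaging import version` import can go. A minimal standalone sketch of the removed import-time gate, assuming the `packaging` distribution is installed:

```python
# Sketch of the import-time version gate this commit removes.
# version.parse tolerates local suffixes such as "2.1.0+cu118".
import torch
from packaging import version

TORCH_FX_REQUIRED_VERSION = version.parse("1.12")

torch_version = version.parse(torch.__version__)
assert (torch_version.major, torch_version.minor) >= (
    TORCH_FX_REQUIRED_VERSION.major,
    TORCH_FX_REQUIRED_VERSION.minor,
), "PiPPy requires PyTorch >= 1.12"
```
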
@@ -931,134 +917,164 @@ def move_param_to_callee(
931 | 917 |             if node.op == "get_attr" and len(node.users) > 1:
932 | 918 |                 multi_use_params_qualnames.setdefault(node.target)
933 | 919 |
934 | | -        # TODO: re-enable support for multi-use parameters
935 | | -        assert len(multi_use_params_qualnames) == 0
936 | | -
937 | | -        """
938 | | -        for param in multi_use_params_qualnames:
939 | | -            if isinstance(multi_use_param_spec, MultiUseParameterConfig):
940 | | -                multi_use_params_qualnames[param] = multi_use_param_spec
941 | | -            elif isinstance(multi_use_param_spec, dict):
942 | | -                multi_use_params_qualnames[param] = multi_use_param_spec.get(
943 | | -                    param, MultiUseParameterConfig.TRANSMIT
944 | | -                )
945 | | -            else:
946 | | -                raise ValueError(
947 | | -                    "multi_use_param_spec must be MultiUseParamSpec enum or dict"
948 | | -                )
949 | | -
950 | | -        # TODO: do we maintain the invariant that `Node.users` is topologically ordered? I don't think so
951 | | -        node_to_first_user: Dict[fx.Node, fx.Node] = {}
952 | | -        for node in split.graph.nodes:
953 | | -            for input in node.all_input_nodes:
954 | | -                if input not in node_to_first_user:
955 | | -                    node_to_first_user[input] = node
956 | | -
957 | | -        for node in split.graph.nodes:
958 | | -            if (
959 | | -                node.op == "get_attr"
960 | | -                and node.target in multi_use_params_qualnames
961 | | -            ):
962 | | -                reuse_type = multi_use_params_qualnames[node.target]
963 | | -                if reuse_type == MultiUseParameterConfig.TRANSMIT:
964 | | -                    first_user = node_to_first_user[node]
965 | | -                    assert first_user.op == "call_module"
966 | | -
967 | | -                    use_idx = delete_user_reference(
968 | | -                        node, first_user, delete_node=False
| 920 | +        def set_multi_use_param_spec(
| 921 | +            multi_use_params_qualnames,
| 922 | +            multi_use_param_spec,
| 923 | +        ):
| 924 | +            for param in multi_use_params_qualnames:
| 925 | +                if isinstance(multi_use_param_spec, MultiUseParameterConfig):
| 926 | +                    multi_use_params_qualnames[param] = multi_use_param_spec
| 927 | +                elif isinstance(multi_use_param_spec, dict):
| 928 | +                    multi_use_params_qualnames[
| 929 | +                        param
| 930 | +                    ] = multi_use_param_spec.get(
| 931 | +                        param, MultiUseParameterConfig.TRANSMIT
969 | 932 |                     )
970 | | -
971 | | -                    atoms = node.target.split(".")
972 | | -                    mod_itr = split
973 | | -                    for atom in atoms[:-1]:
974 | | -                        mod_itr = getattr(mod_itr, atom)
975 | | -                    param_val = getattr(mod_itr, atoms[-1])
976 | | -                    is_buffer = atoms[-1] in mod_itr._buffers
977 | | -
978 | | -                    callee_param_def = move_param_to_callee(
979 | | -                        split, first_user.target, param_val, use_idx, is_buffer
| 933 | +                else:
| 934 | +                    raise ValueError(
| 935 | +                        "multi_use_param_spec must be MultiUseParamSpec enum or dict"
980 | 936 |                     )
981 | 937 |
982 | | -                    delattr(mod_itr, atoms[-1])
983 | | -
984 | | -                    # Add extra output to the callee and switch references to the parameter
985 | | -                    # access in the pipeline graph to use this.
986 | | -                    submod = split.get_submodule(first_user.target)
987 | | -                    callee_output_nodes = [
988 | | -                        n for n in submod.graph.nodes if n.op == "output"
989 | | -                    ]
990 | | -                    assert len(callee_output_nodes) == 1
991 | | -                    callee_output_node = callee_output_nodes[0]
| 938 | +        def handle_multi_use_params(
| 939 | +            split,
| 940 | +            multi_use_params_qualnames,
| 941 | +        ):
| 942 | +            # TODO: do we maintain the invariant that `Node.users` is topologically ordered? I don't think so
| 943 | +            node_to_first_user: Dict[fx.Node, fx.Node] = {}
| 944 | +            for node in split.graph.nodes:
| 945 | +                for input in node.all_input_nodes:
| 946 | +                    if input not in node_to_first_user:
| 947 | +                        node_to_first_user[input] = node
| 948 | +
| 949 | +            for node in split.graph.nodes:
| 950 | +                if (
| 951 | +                    node.op == "get_attr"
| 952 | +                    and node.target in multi_use_params_qualnames
| 953 | +                ):
| 954 | +                    reuse_type = multi_use_params_qualnames[node.target]
| 955 | +                    if reuse_type == MultiUseParameterConfig.TRANSMIT:
| 956 | +                        first_user = node_to_first_user[node]
| 957 | +                        assert first_user.op == "call_module"
992 | 958 |
993 | | -                    # TODO: zero outputs?
994 | | -                    if isinstance(callee_output_node.args[0], tuple):
995 | | -                        new_output_args = callee_output_node.args[0] + (
996 | | -                            callee_param_def,
997 | | -                        )
998 | | -                        callee_output_node.args = (new_output_args,)
999 | | -                        new_output_idx = len(new_output_args) - 1
1000 | | -                        promoted_to_tuple = False
1001 | | -                    else:
1002 | | -                        new_output_args = (
1003 | | -                            callee_output_node.args[0],
1004 | | -                            callee_param_def,
1005 | | -                        )
1006 | | -                        callee_output_node.args = (new_output_args,)
1007 | | -                        new_output_idx = len(new_output_args) - 1
1008 | | -                        promoted_to_tuple = True
1009 | | -
1010 | | -                    submod.graph.lint()
1011 | | -                    submod.recompile()
1012 | | -
1013 | | -                    with split.graph.inserting_after(first_user):
1014 | | -                        if promoted_to_tuple:
1015 | | -                            # TODO: test this code path
1016 | | -                            orig_output_getitem = split.graph.call_function(
1017 | | -                                operator.getitem, (first_user, 0)
1018 | | -                            )
1019 | | -                            first_user.replace_all_uses_with(
1020 | | -                                orig_output_getitem
1021 | | -                            )
1022 | | -                            # HACK because the above replace_all_uses with ALSO replaced the instance
1023 | | -                            # of first_user within the getitem node we just added
1024 | | -                            orig_output_getitem.args = (
1025 | | -                                first_user,
1026 | | -                            ) + orig_output_getitem.args[1:]
1027 | | -
1028 | | -                        transmitted_value_getitem = split.graph.call_function(
1029 | | -                            operator.getitem, (first_user, new_output_idx)
1030 | | -                        )
1031 | | -                        node.replace_all_uses_with(transmitted_value_getitem)
1032 | | -                        split.graph.erase_node(node)
1033 | | -                elif reuse_type == MultiUseParameterConfig.REPLICATE:
1034 | | -                    for user in copy.copy(node.users):
1035 | 959 |                         use_idx = delete_user_reference(
1036 | | -                            node, user, delete_node=False
| 960 | +                            node, first_user, delete_node=False
1037 | 961 |                         )
| 962 | +
1038 | 963 |                         atoms = node.target.split(".")
1039 | 964 |                         mod_itr = split
1040 | 965 |                         for atom in atoms[:-1]:
1041 | 966 |                             mod_itr = getattr(mod_itr, atom)
1042 | 967 |                         param_val = getattr(mod_itr, atoms[-1])
1043 | 968 |                         is_buffer = atoms[-1] in mod_itr._buffers
1044 | 969 |
1045 | | -                        move_param_to_callee(
1046 | | -                            split, user.target, param_val, use_idx, is_buffer
| 970 | +                        callee_param_def = move_param_to_callee(  # type: ignore[call-arg]
| 971 | +                            split,
| 972 | +                            first_user.target,
| 973 | +                            param_val,
| 974 | +                            use_idx,
| 975 | +                            is_buffer,
1047 | 976 |                         )
1048 | 977 |
1049 | | -                    atoms = node.target.split(".")
1050 | | -                    mod_itr = split
1051 | | -                    for atom in atoms[:-1]:
1052 | | -                        mod_itr = getattr(mod_itr, atom)
| 978 | +                        delattr(mod_itr, atoms[-1])
| 979 | +
| 980 | +                        # Add extra output to the callee and switch references to the parameter
| 981 | +                        # access in the pipeline graph to use this.
| 982 | +                        submod = split.get_submodule(first_user.target)
| 983 | +                        callee_output_nodes = [
| 984 | +                            n for n in submod.graph.nodes if n.op == "output"
| 985 | +                        ]
| 986 | +                        assert len(callee_output_nodes) == 1
| 987 | +                        callee_output_node = callee_output_nodes[0]
| 988 | +
| 989 | +                        # TODO: zero outputs?
| 990 | +                        if isinstance(callee_output_node.args[0], tuple):
| 991 | +                            new_output_args = callee_output_node.args[0] + (
| 992 | +                                callee_param_def,
| 993 | +                            )
| 994 | +                            callee_output_node.args = (new_output_args,)
| 995 | +                            new_output_idx = len(new_output_args) - 1
| 996 | +                            promoted_to_tuple = False
| 997 | +                        else:
| 998 | +                            new_output_args = (
| 999 | +                                callee_output_node.args[0],
| 1000 | +                                callee_param_def,
| 1001 | +                            )
| 1002 | +                            callee_output_node.args = (new_output_args,)
| 1003 | +                            new_output_idx = len(new_output_args) - 1
| 1004 | +                            promoted_to_tuple = True
| 1005 | +
| 1006 | +                        submod.graph.lint()
| 1007 | +                        submod.recompile()
| 1008 | +
| 1009 | +                        with split.graph.inserting_after(first_user):
| 1010 | +                            if promoted_to_tuple:
| 1011 | +                                # TODO: test this code path
| 1012 | +                                orig_output_getitem = split.graph.call_function(
| 1013 | +                                    operator.getitem, (first_user, 0)
| 1014 | +                                )
| 1015 | +                                first_user.replace_all_uses_with(
| 1016 | +                                    orig_output_getitem
| 1017 | +                                )
| 1018 | +                                # HACK because the above replace_all_uses with ALSO replaced the instance
| 1019 | +                                # of first_user within the getitem node we just added
| 1020 | +                                orig_output_getitem.args = (
| 1021 | +                                    first_user,
| 1022 | +                                ) + orig_output_getitem.args[1:]
| 1023 | +
| 1024 | +                            transmitted_value_getitem = (
| 1025 | +                                split.graph.call_function(
| 1026 | +                                    operator.getitem,
| 1027 | +                                    (first_user, new_output_idx),
| 1028 | +                                )
| 1029 | +                            )
| 1030 | +                            node.replace_all_uses_with(
| 1031 | +                                transmitted_value_getitem
| 1032 | +                            )
| 1033 | +                            split.graph.erase_node(node)
| 1034 | +                    elif reuse_type == MultiUseParameterConfig.REPLICATE:
| 1035 | +                        for user in copy.copy(node.users):
| 1036 | +                            use_idx = delete_user_reference(
| 1037 | +                                node, user, delete_node=False
| 1038 | +                            )
| 1039 | +                            atoms = node.target.split(".")
| 1040 | +                            mod_itr = split
| 1041 | +                            for atom in atoms[:-1]:
| 1042 | +                                mod_itr = getattr(mod_itr, atom)
| 1043 | +                            param_val = getattr(mod_itr, atoms[-1])
| 1044 | +                            is_buffer = atoms[-1] in mod_itr._buffers
| 1045 | +
| 1046 | +                            move_param_to_callee(  # type: ignore[call-arg]
| 1047 | +                                split,
| 1048 | +                                user.target,
| 1049 | +                                param_val,
| 1050 | +                                use_idx,
| 1051 | +                                is_buffer,
| 1052 | +                            )
1053 | 1053 |
1054 | | -                    delattr(mod_itr, atoms[-1])
| 1054 | +                        atoms = node.target.split(".")
| 1055 | +                        mod_itr = split
| 1056 | +                        for atom in atoms[:-1]:
| 1057 | +                            mod_itr = getattr(mod_itr, atom)
1055 | 1058 |
1056 | | -                    split.graph.erase_node(node)
1057 | | -                else:
1058 | | -                    raise ValueError(
1059 | | -                        f"Unknown multi-use config value {reuse_type} specified for {node.target}"
1060 | | -                    )
1061 | | -        """
| 1059 | +                        delattr(mod_itr, atoms[-1])
| 1060 | +
| 1061 | +                        split.graph.erase_node(node)
| 1062 | +                    else:
| 1063 | +                        raise ValueError(
| 1064 | +                            f"Unknown multi-use config value {reuse_type} specified for {node.target}"
| 1065 | +                        )
| 1066 | +
| 1067 | +        if len(multi_use_params_qualnames) > 0:
| 1068 | +            # TODO: re-enable support for multi-use parameters
| 1069 | +            raise NotImplementedError(
| 1070 | +                "Sharing model parameters between stages is not yet supported. "
| 1071 | +                "Found the following shared parameters in your model: "
| 1072 | +                f"{multi_use_params_qualnames}"
| 1073 | +            )
| 1074 | +        set_multi_use_param_spec(
| 1075 | +            multi_use_params_qualnames, multi_use_param_spec
| 1076 | +        )
| 1077 | +        handle_multi_use_params(split, multi_use_params_qualnames)
1062 | 1078 |
1063 | 1079 |         split.delete_all_unused_submodules()
1064 | 1080 |         split.graph.lint()

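The hunk above flags any parameter read by more than one pipeline stage (a `get_attr` node with multiple users) and now fails fast with `NotImplementedError`; the TRANSMIT path (the first using stage keeps the parameter and forwards it as an extra output) and the REPLICATE path (every using stage receives its own copy) are retained inside `set_multi_use_param_spec` and `handle_multi_use_params` for later re-enablement. A hedged, standalone sketch of how a multi-use parameter surfaces in an fx graph (`TiedBlock` is a hypothetical module, not from this codebase):

```python
# Sketch: a parameter used twice in forward() yields one get_attr node with
# two users, because symbolic_trace caches parameter accesses.
import torch
import torch.fx as fx

class TiedBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        y = x @ self.weight    # first use
        return y @ self.weight # second use of the same parameter

gm = fx.symbolic_trace(TiedBlock())
multi_use = [
    n.target
    for n in gm.graph.nodes
    if n.op == "get_attr" and len(n.users) > 1
]
print(multi_use)  # ['weight']
```

`len(node.users) > 1` is exactly the condition the pass checks; after this change, a model whose split leaves such a node spanning stages raises the `NotImplementedError` above instead of entering the TRANSMIT/REPLICATE machinery.
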
@@ -1285,18 +1301,6 @@ class SplitPoint(Enum):
1285 | 1301 |     END = 2
1286 | 1302 |
1287 | 1303 |
1288 | | -def _split_before_forwad(self, *args, **kwargs):
1289 | | -    pipe_split()
1290 | | -    return self.orig_forward(*args, **kwargs)
1291 | | -
1292 | | -
1293 | | -def _split_after_forwad(self, *args, **kwargs):
1294 | | -    try:
1295 | | -        return self.orig_forward(*args, **kwargs)
1296 | | -    finally:
1297 | | -        pipe_split()
1298 | | -
1299 | | -
1300 | 1304 | # For backward compatibility, we kept the PipeSplitWrapper class because `class
1301 | 1305 | # SplitPoint` used to be defined in this class.
1302 | 1306 | class PipeSplitWrapper:
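For reference, the two helpers deleted above implemented `SplitPoint.BEGINNING`/`SplitPoint.END` by rebinding a wrapped submodule's `forward` around `pipe_split()`. A hedged sketch of that pattern (`install_split` is a hypothetical wiring helper, and `pipe_split`/`SplitPoint` refer to the definitions in this file):

```python
import types

def _split_before_forward(self, *args, **kwargs):
    pipe_split()  # cut the pipeline right before this submodule runs
    return self.orig_forward(*args, **kwargs)

def _split_after_forward(self, *args, **kwargs):
    try:
        return self.orig_forward(*args, **kwargs)
    finally:
        pipe_split()  # cut the pipeline right after this submodule runs

def install_split(mod, split_point):
    # Hypothetical helper: stash the original forward, then rebind forward
    # to the wrapper matching the requested split point.
    mod.orig_forward = mod.forward
    wrapper = (
        _split_before_forward
        if split_point == SplitPoint.BEGINNING
        else _split_after_forward
    )
    mod.forward = types.MethodType(wrapper, mod)
```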