
 import torch
 import torch.distributed as dist
-import transformer_engine as te
+
+try:
+    import transformer_engine as te
+
+    has_te = True
+except (ModuleNotFoundError, ImportError):
+    has_te = False
+
 from torch import nn

 from internlm.accelerator import get_accelerator
@@ -1011,161 +1018,163 @@ def __init__(
         self.full_weight_shape = torch.Size((num_groups, in_features, out_features))


-class TEColumnParallelLinear(te.pytorch.Linear):
-    """
-    Wrapper for the Transformer-Engine's `Linear` layer.
-    """
+if has_te:

-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        bias: bool,
-        skip_bias_add: bool,
-        is_expert: bool,
-        tp_comm_buffer_name: str = None,
-    ):
-        if is_expert:
-            raise ValueError("Transformer Engine linear layers do not yet support MoE")
-
-        # TE returns a zero length Tensor when bias=False and
-        # return_bias=True, but we prefer None. So in that case we
-        # tell TE to not return the bias, and return None
-        # ourselves. This way our forward always returns two values
-        # and we don't have to deal with the zero length Tensor.
-        self.te_return_bias = skip_bias_add and bias
-        self.is_first_microbatch = True
-
-        extra_kwargs = {"params_dtype": gpc.config.model.dtype}
-        if is_te_min_version("0.12.0"):
-            extra_kwargs["device"] = torch.cuda.current_device()
-
-        if gpc.config.parallel["tensor"]["tp_overlap"]:
-            extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                "tp_comm_bulk_wgrad", True
-            )
-            extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                "tp_comm_bulk_dgrad", True
-            )
-            if is_te_min_version("1.5.0", check_equality=False):
-                extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                    "tp_comm_overlap_ag", True
-                )
-            else:
-                raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
-            assert (
-                tp_comm_buffer_name is not None
-            ), "Buffer name should be set to configure communication overlap settings"
-            extra_kwargs["ub_name"] = tp_comm_buffer_name
-
-        parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
-        tp_size = gpc.get_world_size(parallel_mode)
-        tp_group = gpc.get_group(parallel_mode)
-        super().__init__(
-            in_features=in_features,
-            out_features=out_features,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            tp_group=tp_group,
-            tp_size=tp_size,
-            bias=bias,
-            return_bias=self.te_return_bias,
-            parallel_mode="column",
-            **extra_kwargs,
-        )
-
-    def forward(self, x):
-        """Forward."""
-        _is_first_microbatch = self.is_first_microbatch
-        x = x.transpose(0, 1)
-        out = super().forward(x, is_first_microbatch=_is_first_microbatch)
-        out = out.transpose(0, 1)
-
-        self.is_first_microbatch = False
+    class TEColumnParallelLinear(te.pytorch.Linear):
+        """
+        Wrapper for the Transformer-Engine's `Linear` layer.
+        """

-        return out
+        def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias: bool,
+            skip_bias_add: bool,
+            is_expert: bool,
+            tp_comm_buffer_name: str = None,
+        ):
+            if is_expert:
+                raise ValueError("Transformer Engine linear layers do not yet support MoE")
+
+            # TE returns a zero length Tensor when bias=False and
+            # return_bias=True, but we prefer None. So in that case we
+            # tell TE to not return the bias, and return None
+            # ourselves. This way our forward always returns two values
+            # and we don't have to deal with the zero length Tensor.
+            self.te_return_bias = skip_bias_add and bias
+            self.is_first_microbatch = True
+
+            extra_kwargs = {"params_dtype": gpc.config.model.dtype}
+            if is_te_min_version("0.12.0"):
+                extra_kwargs["device"] = torch.cuda.current_device()
+
+            if gpc.config.parallel["tensor"].get("tp_overlap", False):
+                extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                    "tp_comm_bulk_wgrad", True
+                )
+                extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                    "tp_comm_bulk_dgrad", True
+                )
+                if is_te_min_version("1.5.0", check_equality=False):
+                    extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                        "tp_comm_overlap_ag", True
+                    )
+                else:
+                    raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
+                assert (
+                    tp_comm_buffer_name is not None
+                ), "Buffer name should be set to configure communication overlap settings"
+                extra_kwargs["ub_name"] = tp_comm_buffer_name
+
+            parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
+            tp_size = gpc.get_world_size(parallel_mode)
+            tp_group = gpc.get_group(parallel_mode)
+            super().__init__(
+                in_features=in_features,
+                out_features=out_features,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                tp_group=tp_group,
+                tp_size=tp_size,
+                bias=bias,
+                return_bias=self.te_return_bias,
+                parallel_mode="column",
+                **extra_kwargs,
+            )

+        def forward(self, x):
+            """Forward."""
+            _is_first_microbatch = self.is_first_microbatch
+            x = x.transpose(0, 1)
+            out = super().forward(x, is_first_microbatch=_is_first_microbatch)
+            out = out.transpose(0, 1)

-class TERowParallelLinear(te.pytorch.Linear):
-    """
-    Wrapper for the Transformer-Engine's `Linear` layer.
-    """
+            self.is_first_microbatch = False

-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        bias: bool,
-        skip_bias_add: bool,
-        is_expert: bool = False,
-        tp_comm_buffer_name: str = None,
-    ):
-        # TE returns a zero length Tensor when bias=False and
-        # return_bias=True. Here we need a single Tensor
-        self.te_return_bias = skip_bias_add and bias
-        self.is_first_microbatch = True
-
-        extra_kwargs = {"params_dtype": gpc.config.model.dtype}
-        if is_te_min_version("0.12.0"):
-            extra_kwargs["device"] = torch.cuda.current_device()
-
-        if gpc.config.parallel["tensor"]["tp_overlap"]:
-            if is_te_min_version("1.5.0"):
-                extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                    "tp_comm_overlap_ag", True
-                )
-                extra_kwargs["ub_overlap_rs"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
-                    "tp_comm_overlap_rs", True
-                )
-                # Disable ub overlap for experts.
-                if is_expert:
-                    extra_kwargs["ub_overlap_ag"] = False
-                    extra_kwargs["ub_overlap_rs"] = False
-            else:
-                raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
-            assert (
-                tp_comm_buffer_name is not None
-            ), "Buffer name should be set to configure communication overlap settings"
-            extra_kwargs["ub_name"] = tp_comm_buffer_name
+            return out

-        self.expert_parallel = gpc.config.parallel["expert"].get("size", 1) > 1
-        parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
-        # Disable communications in TE when using TP or EP by making TE agnostic of model parallel.
-        tp_size = gpc.get_world_size(parallel_mode)
-        tp_group = gpc.get_group(parallel_mode)
-        explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
-
-        split_mode = "row"
-        if explicit_expert_comm:
-            assert in_features % tp_size == 0, "{} is not divisible by {}".format(in_features, tp_size)
-            in_features = in_features // tp_size
-            split_mode = None
-            tp_size = 1
-            tp_group = None
+    class TERowParallelLinear(te.pytorch.Linear):
+        """
+        Wrapper for the Transformer-Engine's `Linear` layer.
+        """

-        super().__init__(
-            in_features=in_features,
-            out_features=out_features,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            tp_group=tp_group,
-            tp_size=tp_size,
-            bias=bias,
-            return_bias=self.te_return_bias,
-            parallel_mode=split_mode,
-            **extra_kwargs,
-        )
+        def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias: bool,
+            skip_bias_add: bool,
+            is_expert: bool = False,
+            tp_comm_buffer_name: str = None,
+        ):
+            # TE returns a zero length Tensor when bias=False and
+            # return_bias=True. Here we need a single Tensor
+            self.te_return_bias = skip_bias_add and bias
+            self.is_first_microbatch = True
+
+            extra_kwargs = {"params_dtype": gpc.config.model.dtype}
+            if is_te_min_version("0.12.0"):
+                extra_kwargs["device"] = torch.cuda.current_device()
+
+            if gpc.config.parallel["tensor"].get("tp_overlap", False):
+                if is_te_min_version("1.5.0"):
+                    extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                        "tp_comm_overlap_ag", True
+                    )
+                    extra_kwargs["ub_overlap_rs"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get(
+                        "tp_comm_overlap_rs", True
+                    )
+                    # Disable ub overlap for experts.
+                    if is_expert:
+                        extra_kwargs["ub_overlap_ag"] = False
+                        extra_kwargs["ub_overlap_rs"] = False
+                else:
+                    raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0")
+                assert (
+                    tp_comm_buffer_name is not None
+                ), "Buffer name should be set to configure communication overlap settings"
+                extra_kwargs["ub_name"] = tp_comm_buffer_name
+
+            self.expert_parallel = gpc.config.parallel["expert"].get("size", 1) > 1
+            parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
+            # Disable communications in TE when using TP or EP by making TE agnostic of model parallel.
+            tp_size = gpc.get_world_size(parallel_mode)
+            tp_group = gpc.get_group(parallel_mode)
+            explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
+
+            split_mode = "row"
+            if explicit_expert_comm:
+                assert in_features % tp_size == 0, "{} is not divisible by {}".format(in_features, tp_size)
+                in_features = in_features // tp_size
+                split_mode = None
+                tp_size = 1
+                tp_group = None
+
+            super().__init__(
+                in_features=in_features,
+                out_features=out_features,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                tp_group=tp_group,
+                tp_size=tp_size,
+                bias=bias,
+                return_bias=self.te_return_bias,
+                parallel_mode=split_mode,
+                **extra_kwargs,
+            )

-        for param in self.parameters():
-            setattr(param, "allreduce", not (is_expert and self.expert_parallel))
+        def forward(self, x):
+            """Forward."""
+            _is_first_microbatch = self.is_first_microbatch
+            x = x.transpose(0, 1)
+            out = super().forward(x, is_first_microbatch=_is_first_microbatch)
+            out = out.transpose(0, 1)
+            self.is_first_microbatch = False

-    def forward(self, x):
-        """Forward."""
-        _is_first_microbatch = self.is_first_microbatch
-        x = x.transpose(0, 1)
-        out = super().forward(x, is_first_microbatch=_is_first_microbatch)
-        out = out.transpose(0, 1)
-        self.is_first_microbatch = False
+            return out

-        return out
+else:
+    TEColumnParallelLinear = ColumnParallelLinear
+    TERowParallelLinear = RowParallelLinear


 def new_linear(
@@ -1217,7 +1226,7 @@ def new_linear(
             weight_scale=weight_scale,
             norm_head=norm_head,
         )
-    elif split_mode == "column":
+    elif split_mode == "column" or (split_mode == "tecolumn" and not has_te):
         return ColumnParallelLinear(
             in_features,
             out_features,
@@ -1236,7 +1245,7 @@ def new_linear(
             is_expert,
             tp_comm_buffer_name,
         )
-    elif split_mode == "row":
+    elif split_mode == "row" or (split_mode == "terow" and not has_te):
         return RowParallelLinear(
             in_features,
             out_features,
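
For reference, the pattern this diff introduces — a guarded import, a module-level availability flag, and fallback aliases so call sites never branch on the optional dependency themselves — can be sketched standalone. The names fancy_linear, FastLinear, and PlainLinear below are hypothetical stand-ins, not identifiers from this repository:

# Minimal, self-contained sketch of the optional-dependency fallback used above.
# The real code guards transformer_engine and falls back to the plain
# ColumnParallelLinear / RowParallelLinear implementations.
try:
    import fancy_linear  # optional accelerated backend (hypothetical package)

    has_fancy = True
except (ModuleNotFoundError, ImportError):
    has_fancy = False


class PlainLinear:
    """Always-available baseline implementation."""

    def forward(self, x):
        return x


if has_fancy:

    class FastLinear(fancy_linear.Linear):
        """Accelerated wrapper, defined only when the backend imports cleanly."""

else:
    # Callers can reference FastLinear unconditionally; dispatch code may also
    # consult the flag directly, as new_linear() does with `not has_te`.
    FastLinear = PlainLinear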