Commit 5ad2eb0

fix(pp): fix pp get tensor shape err and layernorm input dtype err (#378)
1 parent ae2243c commit 5ad2eb0

File tree

6 files changed: +13 -9 lines changed

internlm/core/scheduler/pipeline_scheduler_1f1b.py

Lines changed: 5 additions & 1 deletion

@@ -35,7 +35,11 @@ def get_tensor_shape():
     if not gpc.is_initialized(ParallelMode.PIPELINE):
         return None
 
-    if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
+    if (
+        hasattr(gpc.config.data, "seq_len")
+        and hasattr(gpc.config.data, "micro_bsz")
+        and hasattr(gpc.config.model, "hidden_size")
+    ):
         if gpc.config.data.use_packed_dataset and gpc.is_evaluating is False:
             if gpc.config.parallel.sequence_parallel:
                 sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)

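Why the guard changed: the old check looked up legacy uppercase keys (gpc.config.SEQ_LEN, gpc.config.HIDDEN_SIZE) on the top-level config, while the current config keeps these values as gpc.config.data.seq_len and gpc.config.model.hidden_size, so the branch that derives the pipeline tensor shape could be skipped entirely. Below is a minimal, illustrative sketch of the corrected guard, assuming the usual (micro_bsz, seq_len, hidden_size) activation layout; it is not the repository function and omits the packed-dataset and sequence-parallel branches shown above.

# Illustrative sketch only: derive a pipeline communication shape from the
# nested config, mirroring the corrected attribute check above.
def sketch_get_tensor_shape(config):
    if (
        hasattr(config.data, "seq_len")
        and hasattr(config.data, "micro_bsz")
        and hasattr(config.model, "hidden_size")
    ):
        # assumed (batch, sequence, hidden) layout for inter-stage activations
        return (config.data.micro_bsz, config.data.seq_len, config.model.hidden_size)
    return None
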
internlm/model/modeling_internlm.py

Lines changed: 2 additions & 2 deletions

@@ -195,7 +195,7 @@ def _forward(self, hidden_states, *args, **kwargs):
         def _dropout_and_norm_attn(_hidden_states):
             _dropped = self.dropout1(_hidden_states)
             _residual = _dropped
-            _hidden_states = self.norm1(_residual.float())
+            _hidden_states = self.norm1(_residual.to(self.norm1.weight.dtype))
             return _residual, _hidden_states
 
         if self.dropout_selective_checkpoint:
@@ -212,7 +212,7 @@ def _dropout_and_norm_attn(_hidden_states):
         def _dropout_and_norm_ffn(_residual, _hidden_states):
             _dropped = self.dropout2(_hidden_states)
             _residual = (_dropped + _residual) if _residual is not None else _dropped
-            _hidden_states = self.norm2(_residual.float())
+            _hidden_states = self.norm2(_residual.to(self.norm2.weight.dtype))
             return _residual, _hidden_states
 
         if self.dropout_selective_checkpoint:

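In this file and the modeling files below, the residual was previously upcast to float32 (via .float() or .to(torch.float32)) before the norm, while under mixed-precision training the norm's own weight stays in the low-precision dtype; casting the input to the weight's dtype keeps input and parameters consistent instead of hard-coding float32. A minimal sketch of the pattern, using a plain torch.nn.LayerNorm as a stand-in (an assumption; the repository's norm layers may be fused or RMS norms):

import torch
import torch.nn as nn

# Stand-in norm whose parameters live in bf16, as they would under
# mixed-precision training.
norm = nn.LayerNorm(4096).to(torch.bfloat16)
residual = torch.randn(2, 4096, dtype=torch.bfloat16)

# Before: the residual was upcast to fp32, so input and weight dtypes diverged.
# hidden = norm(residual.float())

# After: cast the input to whatever dtype the norm's weight actually has,
# so the two agree regardless of the training precision.
hidden = norm(residual.to(norm.weight.dtype))
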
internlm/model/modeling_internlm2.py

Lines changed: 1 addition & 1 deletion

@@ -257,7 +257,7 @@ def _dropout_and_norm_attn(_residual, _hidden_states):
         def _dropout_and_norm_ffn(_residual, _hidden_states):
             _dropped = self.dropout2(_hidden_states)
             _residual = (_dropped + _residual) if _residual is not None else _dropped
-            _hidden_states = self.ffn_norm(_residual.to(torch.float32))
+            _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype))
 
             return _residual, _hidden_states

internlm/model/modeling_llama.py

Lines changed: 1 addition & 1 deletion

@@ -246,7 +246,7 @@ def _dropout_and_norm_attn(_residual, _hidden_states):
         def _dropout_and_norm_ffn(_residual, _hidden_states):
             _dropped = self.dropout2(_hidden_states)
             _residual = (_dropped + _residual) if _residual is not None else _dropped
-            _hidden_states = self.ffn_norm(_residual.to(torch.float32))
+            _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype))
 
             return _residual, _hidden_states

internlm/model/modeling_mixtral.py

Lines changed: 2 additions & 2 deletions

@@ -214,7 +214,7 @@ def _forward(self, hidden_states, *args, **kwargs):
         def _dropout_and_norm_attn(_hidden_states):
             _dropped = self.dropout1(_hidden_states)
             _residual = _dropped
-            _hidden_states = self.norm1(_residual.float())
+            _hidden_states = self.norm1(_residual.to(self.norm1.weight.dtype))
             return _residual, _hidden_states
 
         if self.dropout_selective_checkpoint:
@@ -231,7 +231,7 @@ def _dropout_and_norm_attn(_hidden_states):
         def _dropout_and_norm_ffn(_residual, _hidden_states):
             _dropped = self.dropout2(_hidden_states)
             _residual = (_dropped + _residual) if _residual is not None else _dropped
-            _hidden_states = self.norm2(_residual.float())
+            _hidden_states = self.norm2(_residual.to(self.norm2.weight.dtype))
             return _residual, _hidden_states
 
         if self.dropout_selective_checkpoint:

internlm/model/modeling_moe.py

Lines changed: 2 additions & 2 deletions

@@ -205,7 +205,7 @@ def _forward(self, hidden_states, *args, **kwargs):
         def _dropout_and_norm_attn(_hidden_states):
             _dropped = self.dropout1(_hidden_states)
             _residual = _dropped
-            _hidden_states = self.norm1(_residual.float())
+            _hidden_states = self.norm1(_residual.to(self.norm1.weight.dtype))
             return _residual, _hidden_states
 
         if self.dropout_selective_checkpoint:
@@ -222,7 +222,7 @@ def _dropout_and_norm_attn(_hidden_states):
         def _dropout_and_norm_ffn(_residual, _hidden_states):
             _dropped = self.dropout2(_hidden_states)
             _residual = (_dropped + _residual) if _residual is not None else _dropped
-            _hidden_states = self.norm2(_residual.float())
+            _hidden_states = self.norm2(_residual.to(self.norm2.weight.dtype))
             return _residual, _hidden_states
 
         if self.dropout_selective_checkpoint:
