
Commit 141e9eb

feat(loss)/add different operator types for cross_entropy (#386)
1 parent 0ec6cdc commit 141e9eb

22 files changed: +682 −377 lines changed

configs/7B_MoE4_sft.py

+14 lines

@@ -103,6 +103,20 @@
     clip_grad_norm=1.0,
 )
 
+
+# loss config (dict):
+# 1. label_smoothing
+# 2. op_type: cross_entropy operator type; five types are supported for loss computation,
+#    including ["torch_naive", "apex_naive", "py_naive", "flash_vocab_parallel", "py_vocab_parallel"];
+#    default is "py_vocab_parallel".
+#    "torch_naive": cross_entropy imported from torch, i.e. torch.nn.CrossEntropyLoss
+#    "apex_naive": cross_entropy from apex
+#    "py_naive": self-implemented cross_entropy
+#    "flash_vocab_parallel": vocab parallel cross_entropy imported from flash_attn
+#    "py_vocab_parallel": self-implemented vocab parallel cross_entropy
+# * op_types that end with "naive" only support parallel_output=False;
+# * in a no-GPU env, only "torch_naive" and "py_vocab_parallel" are supported.
+
 loss = dict(
     label_smoothing=0,
     moe_loss_coeff=0.1,
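
For this MoE config the visible hunk stops inside the loss dict, so it does not show whether an op_type entry was added here; a minimal sketch of how the block might look when relying on the documented default (the op_type line is illustrative, not taken from the hunk):

loss = dict(
    label_smoothing=0,              # no label smoothing
    moe_loss_coeff=0.1,             # weight of the MoE auxiliary loss
    op_type="py_vocab_parallel",    # illustrative; omitting it falls back to this default
)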

configs/7B_internlm2.py

+15 −3 lines

@@ -98,9 +98,21 @@
     clip_grad_norm=1.0,
 )
 
-loss = dict(
-    label_smoothing=0,
-)
+
+# loss config (dict):
+# 1. label_smoothing
+# 2. op_type: cross_entropy operator type; five types are supported for loss computation,
+#    including ["torch_naive", "apex_naive", "py_naive", "flash_vocab_parallel", "py_vocab_parallel"];
+#    default is "py_vocab_parallel".
+#    "torch_naive": cross_entropy imported from torch, i.e. torch.nn.CrossEntropyLoss
+#    "apex_naive": cross_entropy from apex
+#    "py_naive": self-implemented cross_entropy
+#    "flash_vocab_parallel": vocab parallel cross_entropy imported from flash_attn
+#    "py_vocab_parallel": self-implemented vocab parallel cross_entropy
+
+# * op_types that end with "naive" only support parallel_output=False;
+# * in a no-GPU env, only "torch_naive" and "py_vocab_parallel" are supported.
+loss = dict(label_smoothing=0, op_type="py_vocab_parallel")
 
 
 adam = dict(
     lr=1e-4,

configs/7B_isp_sft.py

+16 lines

@@ -108,8 +108,24 @@
     clip_grad_norm=1.0,
 )
 
+
+# loss config (dict):
+# 1. label_smoothing
+# 2. op_type: cross_entropy operator type; five types are supported for loss computation,
+#    including ["torch_naive", "apex_naive", "py_naive", "flash_vocab_parallel", "py_vocab_parallel"];
+#    default is "py_vocab_parallel".
+#    "torch_naive": cross_entropy imported from torch, i.e. torch.nn.CrossEntropyLoss
+#    "apex_naive": cross_entropy from apex
+#    "py_naive": self-implemented cross_entropy
+#    "flash_vocab_parallel": vocab parallel cross_entropy imported from flash_attn
+#    "py_vocab_parallel": self-implemented vocab parallel cross_entropy
+
+# * op_types that end with "naive" only support parallel_output=False;
+# * in a no-GPU env, only "torch_naive" and "py_vocab_parallel" are supported.
+
 loss = dict(
     label_smoothing=0,
+    op_type="flash_vocab_parallel",
 )
 
 adam = dict(

internlm/core/trainer_builder.py

+6 −4 lines

@@ -16,7 +16,7 @@
 from internlm.data.train_state import get_train_state
 from internlm.eval.evaluation import evaluate_on_val_dls
 from internlm.initialize.initialize_trainer import initialize_trainer
-from internlm.model.losses.ce_loss import FlashGPTLMLoss
+from internlm.model.losses.ce_loss import InternLoss
 from internlm.model.metrics import AccPerplex
 from internlm.monitor.monitor import send_alert_message
 from internlm.train.pipeline import (

@@ -172,9 +172,11 @@ def _read_config(self, config_path: str) -> list:
         with open(config_path, "r") as f:
             return f.readlines()
 
-    def _initialize_criterion(self) -> FlashGPTLMLoss:
-        return FlashGPTLMLoss(
-            parallel_output=gpc.config.model.parallel_output, label_smoothing=gpc.config.loss.label_smoothing
+    def _initialize_criterion(self) -> InternLoss:
+        return InternLoss(
+            parallel_output=gpc.config.model.parallel_output,
+            label_smoothing=gpc.config.loss.label_smoothing,
+            op_type=gpc.config.loss.op_type,
         )
 
     def _initialize_checkpoint_manager(
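
Outside the trainer, the new criterion can be exercised on its own; a minimal sketch, assuming the two-argument (logits, labels) form accepted by forward and an operator type that the config comments list as usable without a GPU:

import torch
from internlm.model.losses.ce_loss import InternLoss

criterion = InternLoss(
    parallel_output=False,      # gather the full-vocab logits instead of the sharded head output
    label_smoothing=0.0,
    op_type="torch_naive",      # torch.nn.CrossEntropyLoss under the hood
)
logits = torch.randn(2, 8, 1000)            # (batch, seq_len, vocab)
labels = torch.randint(0, 1000, (2, 8))
labels[:, -1] = -100                        # ignored positions are excluded from the mean
loss = criterion(logits, labels)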

internlm/initialize/launch.py

+8 −11 lines

@@ -351,17 +351,6 @@ def args_sanity_check():
     if "use_flash_attn" not in gpc.config.model:
         gpc.config.model._add_item("use_flash_attn", True)
 
-    old_parallel_output = gpc.config.model.get("parallel_output", None)
-    # Try to change user setting
-    if internlm_accelerator.get_accelerator_backend() is not AcceleratorType.GPU:
-        gpc.config.model.update({"parallel_output": False})
-        if old_parallel_output is True and gpc.is_rank_for_log():
-            logger.warning(
-                "'parallel_output' is converted from 'True' to 'False'."
-                "Because 'parallel_output' only support by FlashCrossEntropyLoss."
-                "Please make sure you are using flash attention in cuda device."
-            )
-
     if "MoE" in gpc.config.get("model_type", ModelType.INTERNLM.name):
         if "num_experts" not in model:
             model._add_item("num_experts", 1)

@@ -449,6 +438,9 @@ def args_sanity_check():
     ]:
         gpc.config.parallel.sequence_parallel = True
 
+        if gpc.config.model.get("parallel_output", False) is False:
+            logger.warning("When enable sequence parallel, it recommend to enable parallel_output")
+
     # set default value for weight parallel
     if gpc.config.parallel["weight"].get("overlap", None) is None:
         gpc.config.parallel["weight"]["overlap"] = False

@@ -583,6 +575,11 @@ def args_sanity_check():
         gpc.config.data.use_packed_dataset is False
     ), "only unpacked data is supported when using 2D sequence parallel."
 
+    # loss operator type
+    loss_cfg = gpc.config.loss
+    if loss_cfg.get("op_type", None) is None:
+        loss_cfg._add_item("op_type", "py_vocab_parallel")
+
 
 def launch(
     config: Union[str, Path, Config, Dict],
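
The new sanity check only backfills a default operator when the user config omits one; the same pattern, sketched against a plain dict rather than the gpc Config object (a plain assignment stands in for the internal _add_item call):

loss_cfg = {"label_smoothing": 0}                 # user config with no op_type given
if loss_cfg.get("op_type", None) is None:
    loss_cfg["op_type"] = "py_vocab_parallel"     # matches the documented default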

internlm/model/losses/__init__.py

+2 −2 lines

@@ -1,5 +1,5 @@
-from .ce_loss import FlashGPTLMLoss
+from .ce_loss import InternLoss
 
 __all__ = [
-    "FlashGPTLMLoss",
+    "InternLoss",
 ]

internlm/model/losses/ce_loss.py

+53 −19 lines

@@ -1,36 +1,61 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
+import torch
 from torch import nn
 
-from internlm.core.context import global_context as gpc
+from internlm.accelerator import get_accelerator
 from internlm.model.ops.cross_entropy import new_cross_entropy
-from internlm.utils.logger import get_logger
 
-logger = get_logger(__file__)
+internlm_accelerator = get_accelerator()
 
 
-class FlashGPTLMLoss(nn.Module):
-    """
-    Loss function for flash GPT Language Model.
+class InternLoss(nn.Module):
+    """A base class that wraps the different CrossEntropy implementations
+    and unifies their input and output parameters.
+
+    This class is designed not to rely on gpc, making it easy to transplant.
+
+    It supports different variants of CrossEntropy, with parallel computation and in-place operations.
+
+    If parallel_output is False, the output will gather the head's output; only 'FlashCrossEntropyLoss' and
+    'CrossEntropyApexVocabParallel' support it.
     """
 
-    def __init__(self, parallel_output=True, label_smoothing=0):
+    def __init__(
+        self,
+        parallel_output=False,
+        ignore_index=-100,
+        reduction="mean",
+        label_smoothing=0.0,
+        inplace_backward=True,
+        op_type="py_vocab_parallel",
+    ) -> None:
         super().__init__()
 
         if label_smoothing is not None:
             if label_smoothing != 0:
-                if gpc.is_rank_for_log():
-                    print(f"use label_smoothing: {label_smoothing}")
+                print(f"use label_smoothing: {label_smoothing}", flush=True)
             else:
                 label_smoothing = 0
 
         self.label_smoothing = label_smoothing
+
+        self.reduction = reduction
+        self.ignore_index = ignore_index
+        self.op_type = op_type
+
+        assert self.reduction in [
+            "mean",
+            "none",
+        ], f"Only support reduction is mean/none, but the passed in reduction is {self.reduction}"
+
+        # In order to facilitate the calculation of loss for different datasets, we set reduction to 'none'
+        # and do the loss reduction ourselves.
         self.loss_fn = new_cross_entropy(
-            reduction="mean",
-            label_smoothing=self.label_smoothing,
+            op_type=op_type,
+            ignore_index=ignore_index,
+            label_smoothing=label_smoothing,
             parallel_output=parallel_output,
-            inplace_backward=True,
+            inplace_backward=inplace_backward,
+            reduction="none",
         )
 
     def forward(self, *args):

@@ -44,9 +69,18 @@ def forward(self, *args):
             raise RuntimeError(f"The number of criterion inputs are:{len(args)}")
         shift_logits = logits.contiguous().view(-1, logits.size(-1))
         shift_labels = labels.contiguous().view(-1)
-        loss = self.loss_fn(
-            shift_logits, shift_labels
-        )  # There is no need to consider the ignore_index problem here, because the loss calculation will be
-        # calculated through the calculation range, and -100 must be outside this range, so there is no problem
+
+        with torch.autocast(device_type=internlm_accelerator.get_backend_name()):
+            loss_list = self.loss_fn(
+                shift_logits, shift_labels
+            )  # There is no need to consider the ignore_index problem here, because the loss calculation will be
+            # calculated through the calculation range, and -100 must be outside this range, so there is no problem
+
+        cond = shift_labels != self.ignore_index
+        if self.reduction == "mean":
+            # This loss is only for one dp rank.
+            loss = loss_list.sum() / cond.sum()
+        else:
+            loss = loss_list
 
         return loss
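
The manual reduction at the end of forward is just a masked mean over non-ignored tokens; a small self-contained check of that equivalence, assuming plain torch.nn.functional.cross_entropy in place of new_cross_entropy:

import torch
import torch.nn.functional as F

logits = torch.randn(6, 100)
labels = torch.randint(0, 100, (6,))
labels[0] = -100                                    # ignored position

loss_list = F.cross_entropy(logits, labels, reduction="none", ignore_index=-100)
cond = labels != -100
manual_mean = loss_list.sum() / cond.sum()          # reduction done by hand, as in InternLoss
builtin_mean = F.cross_entropy(logits, labels, reduction="mean", ignore_index=-100)
assert torch.allclose(manual_mean, builtin_mean)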

internlm/model/metrics.py

+1 line

@@ -305,6 +305,7 @@ def __init__(self, device, dp_pg, dataset_types: List[str] = None) -> None:
             reduction="none",
             parallel_output=gpc.config.model.parallel_output,
             inplace_backward=True,
+            op_type=gpc.config.loss.op_type,
         )
         self.scatter_sum = scatter_sum_impl
