@@ -1,36 +1,61 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
+import torch
 from torch import nn
 
-from internlm.core.context import global_context as gpc
+from internlm.accelerator import get_accelerator
 from internlm.model.ops.cross_entropy import new_cross_entropy
-from internlm.utils.logger import get_logger
 
-logger = get_logger(__file__)
+internlm_accelerator = get_accelerator()
 
 
-class FlashGPTLMLoss(nn.Module):
-    """
-    Loss function for flash GPT Language Model.
+class InternLoss(nn.Module):
+    """We use a base class to wrap different CrossEntropy implementations
+    and unify their input and output parameters.
+
+    This class is designed not to rely on gpc, making it easy to transplant.
+
+    Different variants of CrossEntropy are supported, including ones with parallel computation and inplace operations.
+
+    If parallel_output is False, the output will gather the head's output; only 'FlashCrossEntropyLoss' and
+    'CrossEntropyApexVocabParallel' support this.
     """
 
-    def __init__(self, parallel_output=True, label_smoothing=0):
+    def __init__(
+        self,
+        parallel_output=False,
+        ignore_index=-100,
+        reduction="mean",
+        label_smoothing=0.0,
+        inplace_backward=True,
+        op_type="py_vocab_parallel",
+    ) -> None:
         super().__init__()
 
         if label_smoothing is not None:
             if label_smoothing != 0:
-                if gpc.is_rank_for_log():
-                    print(f"use label_smoothing: {label_smoothing}")
+                print(f"use label_smoothing: {label_smoothing}", flush=True)
         else:
             label_smoothing = 0
 
         self.label_smoothing = label_smoothing
+
+        self.reduction = reduction
+        self.ignore_index = ignore_index
+        self.op_type = op_type
+
+        assert self.reduction in [
+            "mean",
+            "none",
+        ], f"Only support reduction is mean/none, but the passed in reduction is {self.reduction}"
+
+        # In order to facilitate the calculation of loss for different datasets, we set reduction as 'none',
+        # and do loss reduction ourselves.
         self.loss_fn = new_cross_entropy(
-            reduction="mean",
-            label_smoothing=self.label_smoothing,
+            op_type=op_type,
+            ignore_index=ignore_index,
+            label_smoothing=label_smoothing,
             parallel_output=parallel_output,
-            inplace_backward=True,
+            inplace_backward=inplace_backward,
+            reduction="none",
         )
 
     def forward(self, *args):
@@ -44,9 +69,18 @@ def forward(self, *args):
             raise RuntimeError(f"The number of criterion inputs are:{len(args)}")
         shift_logits = logits.contiguous().view(-1, logits.size(-1))
         shift_labels = labels.contiguous().view(-1)
-        loss = self.loss_fn(
-            shift_logits, shift_labels
-        )  # There is no need to consider the ignore_index problem here, because the loss calculation will be
-        # calculated through the calculation range, and -100 must be outside this range, so there is no problem
+
+        with torch.autocast(device_type=internlm_accelerator.get_backend_name()):
+            loss_list = self.loss_fn(
+                shift_logits, shift_labels
+            )  # There is no need to consider the ignore_index problem here, because the loss calculation will be
+            # calculated through the calculation range, and -100 must be outside this range, so there is no problem
+
+        cond = shift_labels != self.ignore_index
+        if self.reduction == "mean":
+            # This loss is only for one dp rank.
+            loss = loss_list.sum() / (cond).sum()
+        else:
+            loss = loss_list
 
         return loss
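
For reference, a minimal, self-contained sketch of the reduction scheme introduced above: the underlying cross entropy is asked for per-token losses (reduction="none"), and the mean is then taken only over tokens whose label is not ignore_index. Plain torch.nn.functional.cross_entropy stands in for new_cross_entropy here, and the tensor shapes are illustrative rather than taken from this commit.

import torch
import torch.nn.functional as F

ignore_index = -100

# Illustrative shapes: batch 2, sequence length 8, vocabulary size 100.
logits = torch.randn(2, 8, 100)
labels = torch.randint(0, 100, (2, 8))
labels[:, -1] = ignore_index  # e.g. positions without a target token

shift_logits = logits.contiguous().view(-1, logits.size(-1))
shift_labels = labels.contiguous().view(-1)

# Per-token losses; with reduction="none", ignored positions contribute a loss of 0.
loss_list = F.cross_entropy(
    shift_logits, shift_labels, ignore_index=ignore_index, reduction="none"
)

# Manual "mean" reduction over the tokens that actually carry labels,
# mirroring the reduction == "mean" branch in the new forward.
cond = shift_labels != ignore_index
loss = loss_list.sum() / cond.sum()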