-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
+import torch
 from torch import nn
 
-from internlm.core.context import global_context as gpc
+from internlm.accelerator import get_accelerator
 from internlm.model.ops.cross_entropy import new_cross_entropy
-from internlm.utils.logger import get_logger
 
-logger = get_logger(__file__)
+internlm_accelerator = get_accelerator()
 
 
-class FlashGPTLMLoss(nn.Module):
-    """
-    Loss function for flash GPT Language Model.
+class InternLoss(nn.Module):
+    """We use a base class to wrap different CrossEntropy implementations
+    and unify their input and output parameters.
+
+    This class is designed not to rely on gpc, making it easy to transplant.
+
+    It supports different CrossEntropy variants, with optional parallel computation and inplace operations.
+
+    If parallel_output is False, the output will gather the head's output; only 'FlashCrossEntropyLoss' and
+    'CrossEntropyApexVocabParallel' support this.
     """
 
-    def __init__(self, parallel_output=True, label_smoothing=0):
+    def __init__(
+        self,
+        parallel_output=False,
+        ignore_index=-100,
+        reduction="mean",
+        label_smoothing=0.0,
+        inplace_backward=True,
+        op_type="py_vocab_parallel",
+    ) -> None:
         super().__init__()
 
         if label_smoothing is not None:
             if label_smoothing != 0:
-                if gpc.is_rank_for_log():
-                    print(f"use label_smoothing: {label_smoothing}")
+                print(f"use label_smoothing: {label_smoothing}", flush=True)
         else:
             label_smoothing = 0
 
         self.label_smoothing = label_smoothing
+
+        self.reduction = reduction
+        self.ignore_index = ignore_index
+        self.op_type = op_type
+
+        assert self.reduction in [
+            "mean",
+            "none",
+        ], f"Only reduction='mean' or 'none' is supported, but got reduction={self.reduction}"
+
+        # To make it easy to compute the loss for different datasets, we create the underlying
+        # loss with reduction="none" and do the loss reduction ourselves.
         self.loss_fn = new_cross_entropy(
-            reduction="mean",
-            label_smoothing=self.label_smoothing,
+            op_type=op_type,
+            ignore_index=ignore_index,
+            label_smoothing=label_smoothing,
             parallel_output=parallel_output,
-            inplace_backward=True,
+            inplace_backward=inplace_backward,
+            reduction="none",
         )
 
     def forward(self, *args):
@@ -44,9 +69,18 @@ def forward(self, *args):
             raise RuntimeError(f"The number of criterion inputs are:{len(args)}")
         shift_logits = logits.contiguous().view(-1, logits.size(-1))
         shift_labels = labels.contiguous().view(-1)
-        loss = self.loss_fn(
-            shift_logits, shift_labels
-        )  # There is no need to consider ignore_index here: the loss is only computed over the valid
-        # label range, and -100 always falls outside that range, so it never contributes.
+
+        with torch.autocast(device_type=internlm_accelerator.get_backend_name()):
+            loss_list = self.loss_fn(
+                shift_logits, shift_labels
+            )  # There is no need to consider ignore_index here: the loss is only computed over the valid
+            # label range, and -100 always falls outside that range, so it never contributes.
+
+        cond = shift_labels != self.ignore_index
+        if self.reduction == "mean":
+            # This loss is only for one dp rank.
+            loss = loss_list.sum() / cond.sum()
+        else:
+            loss = loss_list
 
         return loss
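
A note on the new reduction path: the wrapped loss is now created with reduction="none" and the "mean" case is reduced manually as loss_list.sum() / cond.sum() over the non-ignored tokens of a single dp rank. Below is a minimal, self-contained sketch of why that manual masked mean matches the usual ignore_index-aware mean; it uses plain torch.nn.functional.cross_entropy as a stand-in for new_cross_entropy (an assumption made only for illustration, not the op this change actually dispatches to).

# Standalone sketch (plain PyTorch, no InternLM dependencies), with
# F.cross_entropy standing in for the wrapped loss op.
import torch
import torch.nn.functional as F

logits = torch.randn(8, 16)      # 8 tokens, vocab size 16
labels = torch.randint(0, 16, (8,))
labels[2] = labels[5] = -100     # ignored positions, e.g. padding

# Per-token losses; ignored positions come back as 0 with reduction="none".
loss_list = F.cross_entropy(logits, labels, ignore_index=-100, reduction="none")

# Manual reduction as in InternLoss.forward: average over non-ignored tokens only.
cond = labels != -100
manual_mean = loss_list.sum() / cond.sum()

# Matches PyTorch's built-in masked mean.
reference = F.cross_entropy(logits, labels, ignore_index=-100, reduction="mean")
assert torch.allclose(manual_mean, reference)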