diff --git a/configs/7B_MoE4_sft.py b/configs/7B_MoE4_sft.py index 4037c0317..8d8acc406 100644 --- a/configs/7B_MoE4_sft.py +++ b/configs/7B_MoE4_sft.py @@ -103,6 +103,20 @@ clip_grad_norm=1.0, ) + +# loss config (dict): +# 1. label_smoothing +# 2. op_type: cross_entropy operator type, we support five types for loss computing, +# including ["torch_naive", "apex_naive", "py_naive", "flash_vocab_parallel", "py_vocab_parallel"] +# default is "py_vocab_parallel". +# "torch_naive": cross_entropy imported from torch, i.e. torch.nn.CrossEntropyLoss +# "apex_naive": cross_entropy from apex +# "py_naive": self-implemented cross_entropy +# "flash_vocab_parallel": vocab parallel cross_entropy imported from flash_attn +# "py_vocab_parallel": self-implemented vocab parallel cross_entropy +# * op_types that ends with "naive" only support parallel_output=False; +# * if in no-GPU env, only "torch_naive" and "py_vocab_parallel" are supported. + loss = dict( label_smoothing=0, moe_loss_coeff=0.1, diff --git a/configs/7B_internlm2.py b/configs/7B_internlm2.py index 97758bba4..51741703e 100644 --- a/configs/7B_internlm2.py +++ b/configs/7B_internlm2.py @@ -98,9 +98,21 @@ clip_grad_norm=1.0, ) -loss = dict( - label_smoothing=0, -) + +# loss config (dict): +# 1. label_smoothing +# 2. op_type: cross_entropy operator type, we support five types for loss computing, +# including ["torch_naive", "apex_naive", "py_naive", "flash_vocab_parallel", "py_vocab_parallel"] +# default is "py_vocab_parallel". +# "torch_naive": cross_entropy imported from torch, i.e. torch.nn.CrossEntropyLoss +# "apex_naive": cross_entropy from apex +# "py_naive": self-implemented cross_entropy +# "flash_vocab_parallel": vocab parallel cross_entropy imported from flash_attn +# "py_vocab_parallel": self-implemented vocab parallel cross_entropy + +# * op_types that ends with "naive" only support parallel_output=False; +# * if in no-GPU env, only "torch_naive" and "py_vocab_parallel" are supported. 
+loss = dict(label_smoothing=0, op_type="py_vocab_parallel") adam = dict( lr=1e-4, diff --git a/configs/7B_isp_sft.py b/configs/7B_isp_sft.py index ad68082d0..39c78660b 100644 --- a/configs/7B_isp_sft.py +++ b/configs/7B_isp_sft.py @@ -108,8 +108,24 @@ clip_grad_norm=1.0, ) + +# loss config (dict): +# 1. label_smoothing +# 2. op_type: cross_entropy operator type, we support five types for loss computing, +# including ["torch_naive", "apex_naive", "py_naive", "flash_vocab_parallel", "py_vocab_parallel"] +# default is "py_vocab_parallel". +# "torch_naive": cross_entropy imported from torch, i.e. torch.nn.CrossEntropyLoss +# "apex_naive": cross_entropy from apex +# "py_naive": self-implemented cross_entropy +# "flash_vocab_parallel": vocab parallel cross_entropy imported from flash_attn +# "py_vocab_parallel": self-implemented vocab parallel cross_entropy + +# * op_types that ends with "naive" only support parallel_output=False; +# * if in no-GPU env, only "torch_naive" and "py_vocab_parallel" are supported. 
+ loss = dict( label_smoothing=0, + op_type="flash_vocab_parallel", ) adam = dict( diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index d0ef284d4..2b82bc1f4 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -16,7 +16,7 @@ from internlm.data.train_state import get_train_state from internlm.eval.evaluation import evaluate_on_val_dls from internlm.initialize.initialize_trainer import initialize_trainer -from internlm.model.losses.ce_loss import FlashGPTLMLoss +from internlm.model.losses.ce_loss import InternLoss from internlm.model.metrics import AccPerplex from internlm.monitor.monitor import send_alert_message from internlm.train.pipeline import ( @@ -172,9 +172,11 @@ def _read_config(self, config_path: str) -> list: with open(config_path, "r") as f: return f.readlines() - def _initialize_criterion(self) -> FlashGPTLMLoss: - return FlashGPTLMLoss( - parallel_output=gpc.config.model.parallel_output, label_smoothing=gpc.config.loss.label_smoothing + def _initialize_criterion(self) -> InternLoss: + return InternLoss( + parallel_output=gpc.config.model.parallel_output, + label_smoothing=gpc.config.loss.label_smoothing, + op_type=gpc.config.loss.op_type, ) def _initialize_checkpoint_manager( diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 35b3d646c..c8b16516d 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -351,17 +351,6 @@ def args_sanity_check(): if "use_flash_attn" not in gpc.config.model: gpc.config.model._add_item("use_flash_attn", True) - old_parallel_output = gpc.config.model.get("parallel_output", None) - # Try to change user setting - if internlm_accelerator.get_accelerator_backend() is not AcceleratorType.GPU: - gpc.config.model.update({"parallel_output": False}) - if old_parallel_output is True and gpc.is_rank_for_log(): - logger.warning( - "'parallel_output' is converted from 'True' to 'False'." 
- "Because 'parallel_output' only support by FlashCrossEntropyLoss." - "Please make sure you are using flash attention in cuda device." - ) - if "MoE" in gpc.config.get("model_type", ModelType.INTERNLM.name): if "num_experts" not in model: model._add_item("num_experts", 1) @@ -449,6 +438,9 @@ def args_sanity_check(): ]: gpc.config.parallel.sequence_parallel = True + if gpc.config.model.get("parallel_output", False) is False: + logger.warning("When enable sequence parallel, it recommend to enable parallel_output") + # set default value for weight parallel if gpc.config.parallel["weight"].get("overlap", None) is None: gpc.config.parallel["weight"]["overlap"] = False @@ -583,6 +575,11 @@ def args_sanity_check(): gpc.config.data.use_packed_dataset is False ), "only unpacked data is supported when using 2D sequence parallel." + # loss operator type + loss_cfg = gpc.config.loss + if loss_cfg.get("op_type", None) is None: + loss_cfg._add_item("op_type", "py_vocab_parallel") + def launch( config: Union[str, Path, Config, Dict], diff --git a/internlm/model/losses/__init__.py b/internlm/model/losses/__init__.py index 582878159..5d6c8db35 100644 --- a/internlm/model/losses/__init__.py +++ b/internlm/model/losses/__init__.py @@ -1,5 +1,5 @@ -from .ce_loss import FlashGPTLMLoss +from .ce_loss import InternLoss __all__ = [ - "FlashGPTLMLoss", + "InternLoss", ] diff --git a/internlm/model/losses/ce_loss.py b/internlm/model/losses/ce_loss.py index 69e09d2fc..5b2a380e8 100644 --- a/internlm/model/losses/ce_loss.py +++ b/internlm/model/losses/ce_loss.py @@ -1,36 +1,61 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - +import torch from torch import nn -from internlm.core.context import global_context as gpc +from internlm.accelerator import get_accelerator from internlm.model.ops.cross_entropy import new_cross_entropy -from internlm.utils.logger import get_logger -logger = get_logger(__file__) +internlm_accelerator = get_accelerator() -class FlashGPTLMLoss(nn.Module): - """ - 
Loss function for flash GPT Language Model. +class InternLoss(nn.Module): + """We use a base class to wrap different CrossEntropy implementations + and unify input and output parameters. + + This class is designed not to rely on gpc, making it easy to transplant. + + Different variants of CrossEntropy, with supporting parallel computation and inplace operations. + + If parallel_output is False, the output will gather head's output, only 'FlashCrossEntropyLoss' and + 'CrossEntropyApexVocabParallel' support it. """ - def __init__(self, parallel_output=True, label_smoothing=0): + def __init__( + self, + parallel_output=False, + ignore_index=-100, + reduction="mean", + label_smoothing=0.0, + inplace_backward=True, + op_type="py_vocab_parallel", + ) -> None: super().__init__() if label_smoothing is not None: if label_smoothing != 0: - if gpc.is_rank_for_log(): - print(f"use label_smoothing: {label_smoothing}") + print(f"use label_smoothing: {label_smoothing}", flush=True) else: label_smoothing = 0 self.label_smoothing = label_smoothing + + self.reduction = reduction + self.ignore_index = ignore_index + self.op_type = op_type + + assert self.reduction in [ + "mean", + "none", + ], f"Only support reduction is mean/none, but the passed in reduction is {self.reduction}" + + # In order to facilitate the calculation of loss for different datasets, we set reduction as 'none', + # and do loss reduction ourselves. 
self.loss_fn = new_cross_entropy( - reduction="mean", - label_smoothing=self.label_smoothing, + op_type=op_type, + ignore_index=ignore_index, + label_smoothing=label_smoothing, parallel_output=parallel_output, - inplace_backward=True, + inplace_backward=inplace_backward, + reduction="none", ) def forward(self, *args): @@ -44,9 +69,18 @@ def forward(self, *args): raise RuntimeError(f"The number of criterion inputs are:{len(args)}") shift_logits = logits.contiguous().view(-1, logits.size(-1)) shift_labels = labels.contiguous().view(-1) - loss = self.loss_fn( - shift_logits, shift_labels - ) # There is no need to consider the ignore_index problem here, because the loss calculation will be - # calculated through the calculation range, and -100 must be outside this range, so there is no problem + + with torch.autocast(device_type=internlm_accelerator.get_backend_name()): + loss_list = self.loss_fn( + shift_logits, shift_labels + ) # There is no need to consider the ignore_index problem here, because the loss calculation will be + # # calculated through the calculation range, and -100 must be outside this range, so there is no problem + + cond = shift_labels != self.ignore_index + if self.reduction == "mean": + # This loss is only for one dp rank. 
+ loss = loss_list.sum() / (cond).sum() + else: + loss = loss_list return loss diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py index af52858f4..a7f6c9668 100644 --- a/internlm/model/metrics.py +++ b/internlm/model/metrics.py @@ -305,6 +305,7 @@ def __init__(self, device, dp_pg, dataset_types: List[str] = None) -> None: reduction="none", parallel_output=gpc.config.model.parallel_output, inplace_backward=True, + op_type=gpc.config.loss.op_type, ) self.scatter_sum = scatter_sum_impl diff --git a/internlm/model/ops/cross_entropy.py b/internlm/model/ops/cross_entropy.py index 82a2da70d..99bf1e047 100644 --- a/internlm/model/ops/cross_entropy.py +++ b/internlm/model/ops/cross_entropy.py @@ -6,354 +6,131 @@ This file implements support for the cross entropy operators. """ +from enum import Enum + import torch -import torch.distributed as dist from torch import nn from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.model.ops.cross_entropy_ops import ( + CrossEntropyApexVocabParallel, + CrossEntropyLossApex, + CrossEntropyPython, +) from internlm.utils.logger import get_logger logger = get_logger(__file__) internlm_accelerator = get_accelerator() -# Adapted from https://github.com/microsoft/Megatron-DeepSpeed/blob/main/megatron/core/ \ -# sequence_parallel/cross_entropy.py -class _VocabSequenceParallelCrossEntropy(torch.autograd.Function): - """ - Cross Entropy module for isp. 
- """ - - @staticmethod - def forward(ctx, vocab_seq_parallel_logits, target, reduction, label_smoothing=0.0): # pylint: disable=W0613 - sp_size = gpc.get_world_size(ParallelMode.TENSOR) - - # reshape - # vocab_seq_parallel_logits: [B * (S/P), V] -> [B, S/P, V] - # target: [B * S/P] -> [B, S/P] - bsz = gpc.config.data.micro_bsz if gpc.config.data.use_packed_dataset is False else 1 - vocab_seq_parallel_logits = vocab_seq_parallel_logits.view(bsz, -1, gpc.config.model.vocab_size) - target = target.view(bsz, -1) - - # transpose - # vocab_seq_parallel_logits: [B, S/P, V] -> [S/P, B, V] - # target: [B, S/P] -> [S/P, B] - # return: [S, B] - vocab_seq_parallel_logits = vocab_seq_parallel_logits.transpose(0, 1).contiguous() - target = target.transpose(0, 1).contiguous() - - ctx.seqlen = vocab_seq_parallel_logits.size(0) * sp_size - batch_size = vocab_seq_parallel_logits.size(1) - - # Need softmax for backward - softmax = torch.nn.functional.softmax(vocab_seq_parallel_logits, dim=-1) - ctx.vocab_size = vocab_seq_parallel_logits.size(2) - loss = torch.nn.functional.nll_loss(softmax.log().view(-1, ctx.vocab_size), target.view(-1), reduction="none") - - loss_all = torch.empty( - ctx.seqlen, batch_size, dtype=vocab_seq_parallel_logits.dtype, device=vocab_seq_parallel_logits.device - ) +def average_losses_across_data_parallel_group(losses): + """Reduce a tensor of losses across all GPUs.""" + averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses]) + torch.distributed.all_reduce(averaged_losses, group=gpc.get_group(ParallelMode.DATA)) + averaged_losses = averaged_losses / gpc.get_world_size(ParallelMode.DATA) - torch.distributed.all_gather_into_tensor(loss_all, loss, group=gpc.get_group(ParallelMode.TENSOR)) + return averaged_losses - # [s b] => [b, s] - loss_all = loss_all.transpose(0, 1).contiguous() - ctx.save_for_backward(softmax, target) +class CrossEntropyOpType(Enum): + torch_naive = 1 # CrossEntropy from torch + flash_vocab_parallel = 2 # 
VocabParallel CorssEntropy from flash_attn + apex_naive = 3 # CrossEntropy from apex + py_vocab_parallel = 4 # self-implemented VocabParallel CrossEntropy + py_naive = 5 # self-implemented CrossEntropy + # sequence_parallel = 6 # self-implemented SequenceParallel CrossEntropy - return loss_all - @staticmethod - def backward(ctx, grad_output): - softmax, target = ctx.saved_tensors +cross_entropy_op_name_map = { + "torch_naive": CrossEntropyOpType.torch_naive, + "flash_vocab_parallel": CrossEntropyOpType.flash_vocab_parallel, + "apex_naive": CrossEntropyOpType.apex_naive, + "py_vocab_parallel": CrossEntropyOpType.py_vocab_parallel, + "py_naive": CrossEntropyOpType.py_naive, + # "sequence_parallel": CrossEntropyOpType.sequence_parallel, +} - # transpose - grad_output = grad_output.transpose(0, 1).contiguous() - step_seqlen = ctx.seqlen // gpc.get_world_size(ParallelMode.TENSOR) - sp_rank = gpc.get_local_rank(ParallelMode.TENSOR) - grad_output_part = grad_output[step_seqlen * sp_rank : step_seqlen * (sp_rank + 1), :] +# TODO: ops是否需要实现更加统一的形式 +def new_cross_entropy( + op_type: str = "py_vocab_parallel", + ignore_index: int = -100, + label_smoothing: float = 0, + parallel_output: bool = False, + inplace_backward: bool = True, + reduction: str = "none", +): + try: + op_type = cross_entropy_op_name_map[op_type] + except KeyError: + raise KeyError(f"op_type only support: {cross_entropy_op_name_map.keys()}") - grad_input = softmax - grad_2d = grad_input.view(-1, ctx.vocab_size) - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + if internlm_accelerator.get_accelerator_backend() is not AcceleratorType.GPU: + assert op_type in [ + CrossEntropyOpType.torch_naive, + CrossEntropyOpType.py_vocab_parallel, + ], "no-GPU env only support 'torch_naive' or 'py_vocab_parallel loss function" - grad_2d[arange_1d, target.view(-1)] -= 1 - grad_input.mul_(grad_output_part.unsqueeze(dim=-1)) + if op_type == CrossEntropyOpType.torch_naive: - # transpose - 
grad_input = grad_input.transpose(0, 1).contiguous() - # reshape - grad_input = grad_input.view(-1, gpc.config.model.vocab_size) + assert parallel_output is False, ( + "'torch_naive' (nn.CrossEntropyLoss) don't support parallel_output, " + "try use 'flash_vocab_parallel' or 'py_vocab_parallel'" + ) - return grad_input, None, None + return nn.CrossEntropyLoss(reduction=reduction, label_smoothing=label_smoothing, ignore_index=ignore_index) + elif op_type == CrossEntropyOpType.flash_vocab_parallel: -def vocab_sequence_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0): - return _VocabSequenceParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing) + assert gpc.get_group(ParallelMode.TENSOR) is not None, "The process group should not be None." + try: + from flash_attn.losses.cross_entropy import ( + CrossEntropyLoss as FlashCrossEntropyLoss, + ) -def average_losses_across_data_parallel_group(losses): - """Reduce a tensor of losses across all GPUs.""" - averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses]) - torch.distributed.all_reduce(averaged_losses, group=gpc.get_group(ParallelMode.DATA)) - averaged_losses = averaged_losses / gpc.get_world_size(ParallelMode.DATA) + flash_cross_entropy_impl = True + except (ModuleNotFoundError, ImportError): + flash_cross_entropy_impl = False - return averaged_losses + assert ( + gpc.config.model.get("use_flash_attn", False) and flash_cross_entropy_impl + ), "Only flash cross entropy support parallel_output" + assert ( + internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU + ), "flash cross entropy only support gpu backend" -class VocabSequenceParallelCrossEntropyLoss(nn.Module): - """ - Cross Entropy module for isp. 
- """ - - def __init__( - self, - ignore_index: int = -100, - reduction: str = "mean", - label_smoothing: float = 0, - process_group=None, - ): - super().__init__() - if reduction not in ["mean", "none"]: - raise NotImplementedError("Only support reduction = 'mean' or 'none'") - self.ignore_index = ignore_index - self.reduction = reduction - self.label_smoothing = label_smoothing - self.process_group = process_group - - def loss_mean_func(self, output_tensor): - losses = output_tensor.float() - loss = torch.sum(losses.view(-1)) / losses.numel() - - # TODO: allreduce loss in dp group - - return loss - - def forward(self, _input, target): - assert _input.is_cuda and target.is_cuda - - _loss_list = vocab_sequence_parallel_cross_entropy(_input, target, self.label_smoothing) - - if self.reduction == "mean": - loss = self.loss_mean_func(_loss_list) - return loss - - return _loss_list.view(-1) - - -class _VocabParallelCrossEntropy(torch.autograd.Function): - """Adapt from: https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py - Supports vocab parallel loss calculation, but does not support inplace backward. - NOTE: This class is different from the original Apex implementation. Apex will calculate the loss of - ignore_index and flashCrossEntropy will set it to 0. InterEvo adapts the second approach. - """ - - @staticmethod - @internlm_accelerator.amp.custom_fwd - def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0, process_group=None): - # Maximum value along vocab dimension across all GPUs. - logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] - if process_group is not None and dist.get_world_size(process_group) > 1: - torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group) - # Subtract the maximum value. 
- vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1) - - # Get the partition's vocab indecies - # get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size - partition_vocab_size = vocab_parallel_logits.size()[-1] - if process_group is not None and dist.get_world_size(process_group) > 1: - rank = dist.get_rank(process_group) - # world_size = dist.get_world_size(process_group) - part_len = vocab_parallel_logits.shape[-1] - vocab_start_index, vocab_end_index = part_len * rank, part_len * (rank + 1) - else: - vocab_start_index, vocab_end_index = 0, vocab_parallel_logits.shape[-1] - - # vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size) - # Create a mask of valid vocab ids (1 means it needs to be masked). - target_mask = (target < vocab_start_index) | (target >= vocab_end_index) - ignore_mask = target == -100 - masked_target = target.clone() - vocab_start_index - masked_target[target_mask] = 0 - - # Get predicted-logits = logits[target]. - # For Simplicity, we convert logits to a 2-D tensor with size - # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. - logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) - masked_target_1d = masked_target.view(-1) - arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) - predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] - predicted_logits_1d = predicted_logits_1d.clone().contiguous() - predicted_logits = predicted_logits_1d.view_as(target) - predicted_logits[target_mask] = 0.0 - - # All reduce is needed to get the chunks from other GPUs. - if process_group is not None and dist.get_world_size(process_group) > 1: - torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group) - - # Sum of exponential of logits along vocab dimension across all GPUs. 
- exp_logits = vocab_parallel_logits - torch.exp(vocab_parallel_logits, out=exp_logits) - sum_exp_logits = exp_logits.sum(dim=-1) - - if process_group is not None and dist.get_world_size(process_group) > 1: - torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group) - - # Normalize and optionally smooth logits - exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) - - # Loss = log(sum(exp(logits))) - predicted-logit. - sum_exp_logits = torch.log(sum_exp_logits) - loss = sum_exp_logits - predicted_logits - loss[ignore_mask] = 0.0 - - vocab_size = exp_logits.size(-1) - if label_smoothing > 0: - r""" - We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth. - = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt}) - = (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i - = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i - = (K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i - = (1 - (alpha * K) / (K - 1)) * y_gt + ( (alpha * K) / (K - 1) ) * \sum_{i} y_i / K - From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py - """ - assert 1.0 > label_smoothing > 0.0 - smoothing = label_smoothing * vocab_size / (vocab_size - 1) - - # Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs. - log_probs = torch.log(exp_logits) - mean_log_probs = log_probs.mean(dim=-1) - loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs - - ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size - # Store softmax, target-mask and masked-target for backward pass. - ctx.save_for_backward(exp_logits, target_mask, masked_target_1d, ignore_mask) - - return loss - - @staticmethod - @internlm_accelerator.amp.custom_bwd - def backward(ctx, grad_output): - - # Retreive tensors from the forward path. 
- softmax, target_mask, masked_target_1d, ignore_mask = ctx.saved_tensors - label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size - - # All the inputs have softmax as thier gradient. - grad_input = softmax # s_{k} - # For simplicity, work with the 2D gradient. - partition_vocab_size = softmax.size()[-1] - grad_2d = grad_input.view(-1, partition_vocab_size) - - # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) - - softmax_update = 1.0 - target_mask.view(-1).float() - - if label_smoothing > 0: - smoothing = label_smoothing * vocab_size / (vocab_size - 1) - grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update - average_grad = 1 / vocab_size - grad_2d[arange_1d, :] -= smoothing * average_grad - else: - grad_2d[arange_1d, masked_target_1d] -= softmax_update - - # Finally elementwise multiplication with the output gradients. - grad_input.mul_(grad_output.unsqueeze(dim=-1)) - grad_input[ignore_mask] = 0.0 # set ignore token loss as 0. - - return grad_input, None, None, None - - -class CrossEntropyApexVocabParallel(nn.Module): - """Adapt from: https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py - Supports vocab parallel loss calculation, but does not support inplace backward. 
- """ - - def __init__( - self, ignore_index=-100, reduction="mean", label_smoothing=0.0, process_group=None, inplace_backward=False - ): - super().__init__() - if reduction not in ["mean", "none"]: - raise NotImplementedError("Only support reduction = 'mean' or 'none'") - assert inplace_backward is False, "does not support inplace backward" - self.ignore_index = ignore_index - self.reduction = reduction - self.label_smoothing = label_smoothing - self.process_group = process_group - - def forward(self, vocab_parallel_logits, target): - # assert vocab_parallel_logits.is_cuda and vocab_parallel_logits.is_cuda - - # SoftmaxCrossEntropyLoss implicitly casts to float - loss = _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, self.label_smoothing, self.process_group) - if self.reduction == "mean": - return loss.sum() / (target != self.ignore_index).sum() - else: - return loss - - -def flash_loss( - ignore_index=-100, - reduction="mean", - label_smoothing=0.0, - process_group=None, - inplace_backward=False, # pylint:disable=W0613 -): - try: - from flash_attn.losses.cross_entropy import ( - CrossEntropyLoss as FlashCrossEntropyLoss, + logger.warning( + "You are using flash_attn cross_entropy operators, \ + which may result loss divergency in long sequence." 
) - flash_cross_entropy_impl = True - except (ModuleNotFoundError, ImportError): - flash_cross_entropy_impl = False - - assert ( - gpc.config.model.get("use_flash_attn", False) and flash_cross_entropy_impl - ), "Only flash cross entropy support parallel_output" - - assert ( - internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU - ), "flash cross entropy only support gpu backend" + return FlashCrossEntropyLoss( + ignore_index=ignore_index, + reduction=reduction, + label_smoothing=label_smoothing, + process_group=gpc.get_group(ParallelMode.TENSOR), + inplace_backward=inplace_backward, + ) - return FlashCrossEntropyLoss( - ignore_index=ignore_index, - reduction=reduction, - label_smoothing=label_smoothing, - process_group=process_group, - ) + elif op_type == CrossEntropyOpType.apex_naive: + assert parallel_output is False, ( + "'apex_naive' (nn.CrossEntropyLoss) can'ts support parallel_output," + "try use 'flash_vocab_parallel' or 'py_vocab_parallel'" + ) + return CrossEntropyLossApex( + ignore_index=ignore_index, + reduction=reduction, + inplace_backward=inplace_backward, + label_smoothing=label_smoothing, + ) -# TODO: ops是否需要实现更加统一的形式 -def new_cross_entropy( - ignore_index: int = -100, - reduction: str = "mean", - label_smoothing: float = 0, - parallel_output: bool = False, - **kwargs, -): - # if is_using_isp() and parallel_output: - # if gpc.is_rank_for_log(): - # logger.warning("Use VocabSequenceParallelCrossEntropyLoss.") - # return VocabSequenceParallelCrossEntropyLoss( - # ignore_index=ignore_index, - # reduction=reduction, - # label_smoothing=label_smoothing, - # process_group=gpc.get_group(ParallelMode.TENSOR), - # ) - - if parallel_output: - # return flash_loss( - # ignore_index=ignore_index, - # reduction=reduction, - # label_smoothing=label_smoothing, - # process_group=gpc.get_group(ParallelMode.TENSOR), - # ) + elif op_type == CrossEntropyOpType.py_vocab_parallel: + assert gpc.get_group(ParallelMode.TENSOR) is not None, "The process group 
should not be None." return CrossEntropyApexVocabParallel( ignore_index=ignore_index, @@ -361,13 +138,13 @@ def new_cross_entropy( label_smoothing=label_smoothing, process_group=gpc.get_group(ParallelMode.TENSOR), ) - else: - if gpc.is_rank_for_log(): - logger.warning( - "Use nn.CrossEntropyLoss rather than flashattn CrossEntropyLoss." - "parallel_output must be set false. Please note this!" - ) - kwargs.pop("inplace_backward", None) - return nn.CrossEntropyLoss( - ignore_index=ignore_index, reduction=reduction, label_smoothing=label_smoothing, **kwargs + + elif op_type == CrossEntropyOpType.py_naive: + assert parallel_output is False, ( + "'py_naive' (nn.CrossEntropyLoss) don't support parallel_output," + "try use 'flash_vocab_parallel' or 'py_vocab_parallel'" ) + return CrossEntropyPython(ignore_index=ignore_index, reduction=reduction) + + else: + raise RuntimeError(f"unkown loss function type: {op_type}") diff --git a/internlm/model/ops/cross_entropy_ops/__init__.py b/internlm/model/ops/cross_entropy_ops/__init__.py new file mode 100644 index 000000000..1f4b6630d --- /dev/null +++ b/internlm/model/ops/cross_entropy_ops/__init__.py @@ -0,0 +1,11 @@ +from .apex_naive_loss import CrossEntropyLossApex +from .py_naive_loss import CrossEntropyPython +from .py_vocab_parallel_loss import CrossEntropyApexVocabParallel +from .sequence_parallel_loss import VocabSequenceParallelCrossEntropyLoss + +__all__ = [ + "CrossEntropyLossApex", + "CrossEntropyPython", + "CrossEntropyApexVocabParallel", + "VocabSequenceParallelCrossEntropyLoss", +] diff --git a/internlm/model/ops/cross_entropy_ops/apex_naive_loss.py b/internlm/model/ops/cross_entropy_ops/apex_naive_loss.py new file mode 100644 index 000000000..139f20a24 --- /dev/null +++ b/internlm/model/ops/cross_entropy_ops/apex_naive_loss.py @@ -0,0 +1,77 @@ +import torch +from torch import nn + +from internlm.accelerator import get_accelerator + +try: + import xentropy_cuda_lib +except (ImportError, ModuleNotFoundError): + 
has_xentropy_cuda_lib = False +else: + has_xentropy_cuda_lib = True + + +internlm_accelerator = get_accelerator() + + +class SoftmaxCrossEntropyLossFn(torch.autograd.Function): + """ + Adapt from: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py + Inplace backward is supported, but loss calculation of vocab parallel is not supported. + NOTE: it should be noted that when the pack_length exceeds 40K, the loss will not decrease. + """ + + @staticmethod + @internlm_accelerator.amp.custom_fwd + def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, inplace_backward=False): + losses, max_log_sum_exp = xentropy_cuda_lib.forward(logits, labels, smoothing) + losses.masked_fill_(labels == padding_idx, 0) + ctx.save_for_backward(logits, max_log_sum_exp, labels) + ctx.smoothing = smoothing + ctx.padding_idx = padding_idx + ctx.inplace_backward = inplace_backward + return losses + + @staticmethod + @internlm_accelerator.amp.custom_bwd + def backward(ctx, grad_loss): + logits, max_log_sum_exp, labels = ctx.saved_tensors + if not grad_loss.is_contiguous(): + grad_loss = grad_loss.contiguous() + grad_loss.masked_fill_(labels == ctx.padding_idx, 0) + grad_logits = xentropy_cuda_lib.backward( + grad_loss, logits, max_log_sum_exp, labels, ctx.smoothing, ctx.inplace_backward + ) + return grad_logits, None, None, None, None + + +class CrossEntropyLossApex(nn.Module): + """ + Inplace backward is supported, but loss calculation of vocab parallel is not supported. + NOTE: it should be noted that when the pack_length exceeds 40K, the loss will not decrease. 
+ """ + + def __init__(self, ignore_index=-100, reduction="mean", label_smoothing=0.0, inplace_backward=False): + super().__init__() + if reduction not in ["mean", "none"]: + raise NotImplementedError("Only support reduction = 'mean' or 'none'") + + assert ( + has_xentropy_cuda_lib is True + ), "The 'xentropy_cuda_lib' package which CrossEntropyLossApex needed was not found in your environment!" + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.inplace_backward = inplace_backward + + def forward(self, logits, target): + # assert logits.is_cuda and target.is_cuda + + # SoftmaxCrossEntropyLoss implicitly casts to float + loss = SoftmaxCrossEntropyLossFn.apply( + logits, target, self.label_smoothing, self.ignore_index, self.inplace_backward + ) + if self.reduction == "mean": + return loss.sum() / (target != self.ignore_index).sum() + else: + return loss diff --git a/internlm/model/ops/cross_entropy_ops/py_naive_loss.py b/internlm/model/ops/cross_entropy_ops/py_naive_loss.py new file mode 100644 index 000000000..f391933f5 --- /dev/null +++ b/internlm/model/ops/cross_entropy_ops/py_naive_loss.py @@ -0,0 +1,83 @@ +import torch +from torch import nn + +from internlm.accelerator import get_accelerator + +internlm_accelerator = get_accelerator() + + +class CrossEntropyWriteInPython(torch.autograd.Function): + """baseline for unit test.""" + + @staticmethod + @internlm_accelerator.amp.custom_fwd + def forward(ctx, logits, target, ignore_idx): + # (1) cal mask + ignore_mask = target == ignore_idx + target[ignore_mask] = 0 + + # (2) safe softmax for logist + logits_max = torch.max(logits, dim=-1)[0] + logits = logits - logits_max.unsqueeze(dim=-1) + + # (3) cal predicted_logits + vocab_size = logits.shape[-1] + logits_2d = logits.view(-1, vocab_size) + target = target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) + predicted_logits = logits_2d[arange_1d, 
target].clone().contiguous().view_as(target)
+
+        # (4) softmax
+        exp_logits = logits
+        torch.exp(logits, out=exp_logits)
+        sum_exp_logits = exp_logits.sum(dim=-1)
+
+        # (5) Normalize and optionally smooth logits
+        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+
+        # (6) cal log
+        sum_exp_logits = torch.log(sum_exp_logits)
+
+        # (7) cal loss
+        loss = sum_exp_logits - predicted_logits
+
+        # (8) apply ignore_mask
+        loss[ignore_mask] = 0.0
+        ctx.save_for_backward(exp_logits, target, ignore_mask)
+        return loss
+
+    @staticmethod
+    @internlm_accelerator.amp.custom_bwd
+    def backward(ctx, grad_output):
+        # The derivation of cross entropy, ref:
+        # https://shivammehta25.github.io/posts/deriving-categorical-cross-entropy-and-softmax/
+        softmax, target, ignore_mask = ctx.saved_tensors
+
+        # Add the gradient from matching classes (which is indicated by target).
+        grad_input = softmax
+        grad_2d = grad_input.view(-1, softmax.shape[-1])
+        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
+        grad_2d[arange_1d, target] -= 1.0
+
+        grad_input.mul_(grad_output.unsqueeze(dim=-1))  # elementwise multiplication
+        grad_input[ignore_mask] = 0.0  # set ignore token loss as 0.
+
+        return grad_input, None, None, None
+
+
+class CrossEntropyPython(nn.Module):
+    """
+    Baseline for unit test. Please do not use this class directly.
+    """
+
+    def __init__(self, ignore_index=-100, reduction="mean"):
+        super().__init__()
+        self.ignore_index = ignore_index
+        self.reduction = reduction
+
+    def forward(self, logits, target):
+        loss = CrossEntropyWriteInPython.apply(logits, target, self.ignore_index)
+        if self.reduction == "mean":
+            return loss.sum() / (target != self.ignore_index).sum()
+        else:
+            return loss
diff --git a/internlm/model/ops/cross_entropy_ops/py_vocab_parallel_loss.py b/internlm/model/ops/cross_entropy_ops/py_vocab_parallel_loss.py
new file mode 100644
index 000000000..6f5457c85
--- /dev/null
+++ b/internlm/model/ops/cross_entropy_ops/py_vocab_parallel_loss.py
@@ -0,0 +1,160 @@
+import torch
+import torch.distributed as dist
+from torch import nn
+
+from internlm.accelerator import get_accelerator
+
+internlm_accelerator = get_accelerator()
+
+
+class _VocabParallelCrossEntropy(torch.autograd.Function):
+    """Adapted from: https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
+    Supports vocab parallel loss calculation, but does not support inplace backward.
+    NOTE: This class is different from the original Apex implementation. Apex will calculate the loss of
+    ignore_index and flashCrossEntropy will set it to 0. InternEvo adopts the second approach.
+    """
+
+    @staticmethod
+    @internlm_accelerator.amp.custom_fwd
+    def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0, process_group=None):
+        # Maximum value along vocab dimension across all GPUs.
+        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
+        if process_group is not None and dist.get_world_size(process_group) > 1:
+            torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group)
+        # Subtract the maximum value.
+        vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)
+
+        # Get the partition's vocab indices
+        # get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
+        partition_vocab_size = vocab_parallel_logits.size()[-1]
+        if process_group is not None and dist.get_world_size(process_group) > 1:
+            rank = dist.get_rank(process_group)
+            # world_size = dist.get_world_size(process_group)
+            part_len = vocab_parallel_logits.shape[-1]
+            vocab_start_index, vocab_end_index = part_len * rank, part_len * (rank + 1)
+        else:
+            vocab_start_index, vocab_end_index = 0, vocab_parallel_logits.shape[-1]
+
+        # vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
+        # Create a mask of valid vocab ids (1 means it needs to be masked).
+        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
+        ignore_mask = target == -100
+        masked_target = target.clone() - vocab_start_index
+        masked_target[target_mask] = 0
+
+        # Get predicted-logits = logits[target].
+        # For simplicity, we convert logits to a 2-D tensor with size
+        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
+        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
+        masked_target_1d = masked_target.view(-1)
+        arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
+        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
+        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
+        predicted_logits = predicted_logits_1d.view_as(target)
+        predicted_logits[target_mask] = 0.0
+
+        # All reduce is needed to get the chunks from other GPUs.
+        if process_group is not None and dist.get_world_size(process_group) > 1:
+            torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
+
+        # Sum of exponential of logits along vocab dimension across all GPUs.
+        exp_logits = vocab_parallel_logits
+        torch.exp(vocab_parallel_logits, out=exp_logits)
+        sum_exp_logits = exp_logits.sum(dim=-1)
+
+        if process_group is not None and dist.get_world_size(process_group) > 1:
+            torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
+
+        # Normalize and optionally smooth logits
+        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+
+        # Loss = log(sum(exp(logits))) - predicted-logit.
+        sum_exp_logits = torch.log(sum_exp_logits)
+        loss = sum_exp_logits - predicted_logits
+        loss[ignore_mask] = 0.0
+
+        vocab_size = exp_logits.size(-1)
+        if label_smoothing > 0:
+            r"""
+            We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth.
+            = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt})
+            = (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
+            = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
+            = (K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i
+            = (1 - (alpha * K) / (K - 1)) * y_gt + ( (alpha * K) / (K - 1) ) * \sum_{i} y_i / K
+            From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py
+            """
+            assert 1.0 > label_smoothing > 0.0
+            smoothing = label_smoothing * vocab_size / (vocab_size - 1)
+
+            # Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs.
+            log_probs = torch.log(exp_logits)
+            mean_log_probs = log_probs.mean(dim=-1)
+            loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs
+
+        ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size
+        # Store softmax, target-mask and masked-target for backward pass.
+        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d, ignore_mask)
+
+        return loss
+
+    @staticmethod
+    @internlm_accelerator.amp.custom_bwd
+    def backward(ctx, grad_output):
+
+        # Retrieve tensors from the forward path.
+        softmax, target_mask, masked_target_1d, ignore_mask = ctx.saved_tensors
+        label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size
+
+        # All the inputs have softmax as their gradient.
+        grad_input = softmax  # s_{k}
+        # For simplicity, work with the 2D gradient.
+        partition_vocab_size = softmax.size()[-1]
+        grad_2d = grad_input.view(-1, partition_vocab_size)
+
+        # Add the gradient from matching classes.
+        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
+
+        softmax_update = 1.0 - target_mask.view(-1).float()
+
+        if label_smoothing > 0:
+            smoothing = label_smoothing * vocab_size / (vocab_size - 1)
+            grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update
+            average_grad = 1 / vocab_size
+            grad_2d[arange_1d, :] -= smoothing * average_grad
+        else:
+            grad_2d[arange_1d, masked_target_1d] -= softmax_update
+
+        # Finally elementwise multiplication with the output gradients.
+        grad_input.mul_(grad_output.unsqueeze(dim=-1))
+        grad_input[ignore_mask] = 0.0  # set ignore token loss as 0.
+
+        return grad_input, None, None, None
+
+
+class CrossEntropyApexVocabParallel(nn.Module):
+    """Adapted from: https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py
+    Supports vocab parallel loss calculation, but does not support inplace backward.
+    """
+
+    def __init__(
+        self, ignore_index=-100, reduction="mean", label_smoothing=0.0, process_group=None, inplace_backward=False
+    ):
+        super().__init__()
+        if reduction not in ["mean", "none"]:
+            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
+        assert inplace_backward is False, "does not support inplace backward"
+        self.ignore_index = ignore_index
+        self.reduction = reduction
+        self.label_smoothing = label_smoothing
+        self.process_group = process_group
+
+    def forward(self, vocab_parallel_logits, target):
+        # assert vocab_parallel_logits.is_cuda and target.is_cuda
+
+        # SoftmaxCrossEntropyLoss implicitly casts to float
+        loss = _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, self.label_smoothing, self.process_group)
+        if self.reduction == "mean":
+            return loss.sum() / (target != self.ignore_index).sum()
+        else:
+            return loss
diff --git a/internlm/model/ops/cross_entropy_ops/sequence_parallel_loss.py b/internlm/model/ops/cross_entropy_ops/sequence_parallel_loss.py
new file mode 100644
index 000000000..2072944f8
--- /dev/null
+++ b/internlm/model/ops/cross_entropy_ops/sequence_parallel_loss.py
@@ -0,0 +1,121 @@
+import torch
+from torch import nn
+
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+
+
+# Adapted from https://github.com/microsoft/Megatron-DeepSpeed/blob/main/megatron/core/ \
+# sequence_parallel/cross_entropy.py
+class _VocabSequenceParallelCrossEntropy(torch.autograd.Function):
+    """
+    Cross Entropy module for isp.
+ """ + + @staticmethod + def forward(ctx, vocab_seq_parallel_logits, target, reduction, label_smoothing=0.0): # pylint: disable=W0613 + sp_size = gpc.get_world_size(ParallelMode.TENSOR) + + # reshape + # vocab_seq_parallel_logits: [B * (S/P), V] -> [B, S/P, V] + # target: [B * S/P] -> [B, S/P] + bsz = gpc.config.data.micro_bsz if gpc.config.data.use_packed_dataset is False else 1 + vocab_seq_parallel_logits = vocab_seq_parallel_logits.view(bsz, -1, gpc.config.model.vocab_size) + target = target.view(bsz, -1) + + # transpose + # vocab_seq_parallel_logits: [B, S/P, V] -> [S/P, B, V] + # target: [B, S/P] -> [S/P, B] + # return: [S, B] + vocab_seq_parallel_logits = vocab_seq_parallel_logits.transpose(0, 1).contiguous() + target = target.transpose(0, 1).contiguous() + + ctx.seqlen = vocab_seq_parallel_logits.size(0) * sp_size + batch_size = vocab_seq_parallel_logits.size(1) + + # Need softmax for backward + softmax = torch.nn.functional.softmax(vocab_seq_parallel_logits, dim=-1) + ctx.vocab_size = vocab_seq_parallel_logits.size(2) + loss = torch.nn.functional.nll_loss(softmax.log().view(-1, ctx.vocab_size), target.view(-1), reduction="none") + + loss_all = torch.empty( + ctx.seqlen, batch_size, dtype=vocab_seq_parallel_logits.dtype, device=vocab_seq_parallel_logits.device + ) + + torch.distributed.all_gather_into_tensor(loss_all, loss, group=gpc.get_group(ParallelMode.TENSOR)) + + # [s b] => [b, s] + loss_all = loss_all.transpose(0, 1).contiguous() + + ctx.save_for_backward(softmax, target) + + return loss_all + + @staticmethod + def backward(ctx, grad_output): + softmax, target = ctx.saved_tensors + + # transpose + grad_output = grad_output.transpose(0, 1).contiguous() + + step_seqlen = ctx.seqlen // gpc.get_world_size(ParallelMode.TENSOR) + sp_rank = gpc.get_local_rank(ParallelMode.TENSOR) + grad_output_part = grad_output[step_seqlen * sp_rank : step_seqlen * (sp_rank + 1), :] + + grad_input = softmax + grad_2d = grad_input.view(-1, ctx.vocab_size) + arange_1d = 
torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + + grad_2d[arange_1d, target.view(-1)] -= 1 + grad_input.mul_(grad_output_part.unsqueeze(dim=-1)) + + # transpose + grad_input = grad_input.transpose(0, 1).contiguous() + # reshape + grad_input = grad_input.view(-1, gpc.config.model.vocab_size) + + return grad_input, None, None + + +def vocab_sequence_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0): + return _VocabSequenceParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing) + + +class VocabSequenceParallelCrossEntropyLoss(nn.Module): + """ + Cross Entropy module for isp. + """ + + def __init__( + self, + ignore_index: int = -100, + reduction: str = "mean", + label_smoothing: float = 0, + process_group=None, + ): + super().__init__() + if reduction not in ["mean", "none"]: + raise NotImplementedError("Only support reduction = 'mean' or 'none'") + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.process_group = process_group + + def loss_mean_func(self, output_tensor): + losses = output_tensor.float() + loss = torch.sum(losses.view(-1)) / losses.numel() + + # TODO: allreduce loss in dp group + + return loss + + def forward(self, _input, target): + assert _input.is_cuda and target.is_cuda + + _loss_list = vocab_sequence_parallel_cross_entropy(_input, target, self.label_smoothing) + + if self.reduction == "mean": + loss = self.loss_mean_func(_loss_list) + return loss + + return _loss_list.view(-1) diff --git a/tests/test_infer/test_trainer_generate.py b/tests/test_infer/test_trainer_generate.py index 3ccbfb54d..537a40777 100644 --- a/tests/test_infer/test_trainer_generate.py +++ b/tests/test_infer/test_trainer_generate.py @@ -10,7 +10,7 @@ from internlm.core.trainer import TrainState, Trainer # noqa: E402 from internlm.data import build_train_loader_with_data_type # noqa: E402 from internlm.initialize import initialize_distributed_env # noqa: 
E402 -from internlm.model.losses import FlashGPTLMLoss # noqa: E402 +from internlm.model.losses import InternLoss # noqa: E402 from internlm.train import ( # noqa: E402 get_scheduler_hooks, initialize_model, @@ -25,7 +25,7 @@ def setup_generator(config, tokenizer): model = initialize_model() isp_communicator = initialize_parallel_communicator(model) - criterion = FlashGPTLMLoss() + criterion = InternLoss() # initialize the train data loader train_dl, _ = build_train_loader_with_data_type() diff --git a/tests/test_training/test_forward_output_no_fa.py b/tests/test_training/test_forward_output_no_fa.py index 48b97bfa8..ba7f0118d 100644 --- a/tests/test_training/test_forward_output_no_fa.py +++ b/tests/test_training/test_forward_output_no_fa.py @@ -15,7 +15,7 @@ from internlm.core.trainer import Trainer from internlm.data import build_train_loader_with_data_type from internlm.initialize.launch import args_sanity_check -from internlm.model.losses import FlashGPTLMLoss +from internlm.model.losses import InternLoss from internlm.model.metrics import AccPerplex, SchedulerMetricHook from internlm.train import ( initialize_model, @@ -175,7 +175,7 @@ def train_check_output(args): _ = initialize_parallel_communicator(model) # initialize loss function - criterion = FlashGPTLMLoss(parallel_output=False, label_smoothing=gpc.config.loss.label_smoothing) + criterion = InternLoss(parallel_output=False, label_smoothing=gpc.config.loss.label_smoothing) optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) diff --git a/tests/test_training/test_load_ckpt_loss.py b/tests/test_training/test_load_ckpt_loss.py index e68905176..ddbb24a08 100644 --- a/tests/test_training/test_load_ckpt_loss.py +++ b/tests/test_training/test_load_ckpt_loss.py @@ -38,7 +38,7 @@ args_sanity_check, ) from internlm.model.losses import ( # noqa: E402 #pylint: disable=wrong-import-position - FlashGPTLMLoss, + InternLoss, ) from internlm.model.metrics import ( # noqa: E402 #pylint: 
disable=wrong-import-position AccPerplex, @@ -224,7 +224,7 @@ def train_model(args): _ = initialize_parallel_communicator(model) # initialize loss function - criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) + criterion = InternLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) # initialize the train and validation data loader train_dl, dataset_types = build_train_loader_with_data_type() diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index 2fd8ad4cb..8b506d2d2 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -13,7 +13,7 @@ from internlm.core.trainer import Trainer, TrainState from internlm.data import build_train_loader_with_data_type from internlm.initialize import initialize_distributed_env -from internlm.model.losses import FlashGPTLMLoss +from internlm.model.losses import InternLoss from internlm.train import ( get_scheduler_hooks, initialize_model, @@ -174,7 +174,7 @@ def train( isp_communicator = initialize_parallel_communicator(model) # initialize loss function - criterion = FlashGPTLMLoss(parallel_output=gpc.config.model.parallel_output, label_smoothing=label_smoothing) + criterion = InternLoss(parallel_output=gpc.config.model.parallel_output, label_smoothing=label_smoothing) # initialize the train data loader train_dl, _ = build_train_loader_with_data_type() diff --git a/tests/test_training/test_no_fa_train_temp.py b/tests/test_training/test_no_fa_train_temp.py index f142e503f..5f0782b4b 100644 --- a/tests/test_training/test_no_fa_train_temp.py +++ b/tests/test_training/test_no_fa_train_temp.py @@ -8,7 +8,7 @@ from internlm.core.context import global_context as gpc from internlm.core.trainer import Trainer from internlm.data import build_train_loader_with_data_type -from internlm.model.losses import FlashGPTLMLoss +from internlm.model.losses import InternLoss from internlm.model.metrics import AccPerplex from 
internlm.train import ( get_scheduler_hooks, @@ -58,7 +58,7 @@ def train_check(args): isp_communicator = initialize_parallel_communicator(model) # initialize loss function - criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) + criterion = InternLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator) diff --git a/tests/test_training/test_norm_weight.py b/tests/test_training/test_norm_weight.py index 0fd24926f..990b334a6 100644 --- a/tests/test_training/test_norm_weight.py +++ b/tests/test_training/test_norm_weight.py @@ -11,7 +11,7 @@ from internlm.core.context import global_context as gpc from internlm.core.trainer import Trainer from internlm.data import build_train_loader_with_data_type -from internlm.model.losses import FlashGPTLMLoss +from internlm.model.losses import InternLoss from internlm.model.metrics import AccPerplex from internlm.train import ( get_scheduler_hooks, @@ -78,7 +78,7 @@ def train_check_norm_weight(args): isp_communicator = initialize_parallel_communicator(model) # initialize loss function - criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) + criterion = InternLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator) diff --git a/tests/test_training/test_swap_nb_loss_and_gradnorm.py b/tests/test_training/test_swap_nb_loss_and_gradnorm.py index 4fa096a5a..13c01b1c5 100644 --- a/tests/test_training/test_swap_nb_loss_and_gradnorm.py +++ b/tests/test_training/test_swap_nb_loss_and_gradnorm.py @@ -21,7 +21,7 @@ ) from internlm.eval.evaluation import switch_evaluation_mode from internlm.initialize.launch import args_sanity_check -from internlm.model.losses import FlashGPTLMLoss +from internlm.model.losses import InternLoss from 
internlm.model.metrics import AccPerplex, SchedulerMetricHook from internlm.train import ( initialize_model, @@ -275,7 +275,7 @@ def exam_loss(args): _ = initialize_parallel_communicator(model) # initialize loss function - criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) + criterion = InternLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) # initialize the train and validation data loader train_dl, dataset_types = build_train_loader_with_data_type() diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py index 7926bae5d..c7da6f85c 100644 --- a/tests/test_training/train_CI.py +++ b/tests/test_training/train_CI.py @@ -27,7 +27,7 @@ ) from internlm.eval.evaluation import evaluate_on_val_dls # noqa: E402 from internlm.initialize import initialize_distributed_env # noqa: E402 -from internlm.model.losses import FlashGPTLMLoss # noqa: E402 +from internlm.model.losses import InternLoss # noqa: E402 from internlm.model.metrics import AccPerplex, SchedulerMetricHook # noqa: E402 from internlm.monitor import ( # noqa: E402 initialize_monitor_manager, @@ -123,7 +123,7 @@ def main(args): config_lines = f.readlines() # initialize loss function - criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing) + criterion = InternLoss(parallel_output=True, label_smoothing=label_smoothing) # initialize the train and validation data loader train_dl, dataset_types = build_train_loader_with_data_type()