diff --git a/finetune/run_c3.py b/finetune/run_c3.py index 3bb3d41..dd52d04 100644 --- a/finetune/run_c3.py +++ b/finetune/run_c3.py @@ -160,14 +160,6 @@ def main(): optimizer, scheduler = build_optimizer(args, model) - if args.fp16: - try: - from apex import amp - except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") - model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) - args.amp = amp - if torch.cuda.device_count() > 1: args.logger.info("{} GPUs are available. Let's use them.".format(torch.cuda.device_count())) model = torch.nn.DataParallel(model) diff --git a/finetune/run_chid.py b/finetune/run_chid.py index d9abd7c..2776ada 100644 --- a/finetune/run_chid.py +++ b/finetune/run_chid.py @@ -179,14 +179,6 @@ def main(): optimizer, scheduler = build_optimizer(args, model) - if args.fp16: - try: - from apex import amp - except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") - model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) - args.amp = amp - if torch.cuda.device_count() > 1: args.logger.info("{} GPUs are available. Let's use them.".format(torch.cuda.device_count())) model = torch.nn.DataParallel(model) diff --git a/finetune/run_classifier.py b/finetune/run_classifier.py index 9750402..79d096f 100644 --- a/finetune/run_classifier.py +++ b/finetune/run_classifier.py @@ -185,11 +185,7 @@ def train_model(args, model, optimizer, scheduler, src_batch, tgt_batch, seg_bat if torch.cuda.device_count() > 1: loss = torch.mean(loss) - if args.fp16: - with args.amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - else: - loss.backward() + loss.backward() if args.use_adv and args.adv_type == "fgm": args.adv_method.attack(epsilon=args.fgm_epsilon) @@ -310,14 +306,6 @@ def main(): args.logger.info("The number of training instances: {}".format(instances_num)) optimizer, scheduler = build_optimizer(args, model) - if args.fp16: - try: - from apex import amp - except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") - model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) - args.amp = amp - if torch.cuda.device_count() > 1: args.logger.info("{} GPUs are available. 
Let's use them.".format(torch.cuda.device_count())) model = torch.nn.DataParallel(model) diff --git a/finetune/run_classifier_cv.py b/finetune/run_classifier_cv.py index bac1c2f..114f29e 100644 --- a/finetune/run_classifier_cv.py +++ b/finetune/run_classifier_cv.py @@ -95,13 +95,7 @@ def main(): model = model.to(args.device) load_or_initialize_parameters(args, model) optimizer, scheduler = build_optimizer(args, model) - if args.fp16: - try: - from apex import amp - except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") - model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) - args.amp = amp + if torch.cuda.device_count() > 1: model = torch.nn.DataParallel(model) args.model = model diff --git a/finetune/run_classifier_grid.py b/finetune/run_classifier_grid.py index f0adc9a..4e28513 100644 --- a/finetune/run_classifier_grid.py +++ b/finetune/run_classifier_grid.py @@ -74,13 +74,7 @@ def main(): model = model.to(args.device) load_or_initialize_parameters(args, model) optimizer, scheduler = build_optimizer(args, model) - if args.fp16: - try: - from apex import amp - except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") - model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) - args.amp = amp + if torch.cuda.device_count() > 1: model = torch.nn.DataParallel(model) args.model = model diff --git a/finetune/run_classifier_mt.py b/finetune/run_classifier_mt.py index 0e88b43..dc1894c 100644 --- a/finetune/run_classifier_mt.py +++ b/finetune/run_classifier_mt.py @@ -158,14 +158,6 @@ def main(): optimizer, scheduler = build_optimizer(args, model) - if args.fp16: - try: - from apex import amp - except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") - model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) - args.amp = amp - if torch.cuda.device_count() > 1: args.logger.info("{} GPUs are available. Let's use them.".format(torch.cuda.device_count())) model = torch.nn.DataParallel(model) diff --git a/finetune/run_classifier_multi_label.py b/finetune/run_classifier_multi_label.py index 953002b..01eb67c 100644 --- a/finetune/run_classifier_multi_label.py +++ b/finetune/run_classifier_multi_label.py @@ -126,11 +126,7 @@ def train_model(args, model, optimizer, scheduler, src_batch, tgt_batch, seg_bat if torch.cuda.device_count() > 1: loss = torch.mean(loss) - if args.fp16: - with args.amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - else: - loss.backward() + loss.backward() if args.use_adv and args.adv_type == "fgm": args.adv_method.attack(epsilon=args.fgm_epsilon) @@ -234,14 +230,6 @@ def main(): args.logger.info("The number of training instances: {}".format(instances_num)) optimizer, scheduler = build_optimizer(args, model) - if args.fp16: - try: - from apex import amp - except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") - model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) - args.amp = amp - if torch.cuda.device_count() > 1: args.logger.info("{} GPUs are available. 
Let's use them.".format(torch.cuda.device_count())) model = torch.nn.DataParallel(model) diff --git a/finetune/run_classifier_prompt.py b/finetune/run_classifier_prompt.py index 138e1a6..b4f7051 100644 --- a/finetune/run_classifier_prompt.py +++ b/finetune/run_classifier_prompt.py @@ -131,11 +131,7 @@ def train_model(args, model, optimizer, scheduler, src_batch, tgt_batch, seg_bat if torch.cuda.device_count() > 1: loss = torch.mean(loss) - if args.fp16: - with args.amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - else: - loss.backward() + loss.backward() optimizer.step() scheduler.step() @@ -257,14 +253,6 @@ def main(): args.logger.info("The number of training instances: {}".format(instances_num)) optimizer, scheduler = build_optimizer(args, model) - if args.fp16: - try: - from apex import amp - except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") - model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) - args.amp = amp - if torch.cuda.device_count() > 1: args.logger.info("{} GPUs are available. Let's use them.".format(torch.cuda.device_count())) model = torch.nn.DataParallel(model) diff --git a/finetune/run_classifier_siamese.py b/finetune/run_classifier_siamese.py index 674afef..5e67f0a 100644 --- a/finetune/run_classifier_siamese.py +++ b/finetune/run_classifier_siamese.py @@ -181,11 +181,7 @@ def train_model(args, model, optimizer, scheduler, src_batch, tgt_batch, seg_bat if torch.cuda.device_count() > 1: loss = torch.mean(loss) - if args.fp16: - with args.amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - else: - loss.backward() + loss.backward() optimizer.step() scheduler.step() @@ -288,14 +284,6 @@ def main(): optimizer, scheduler = build_optimizer(args, model) - if args.fp16: - try: - from apex import amp - except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") - model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) - args.amp = amp - if torch.cuda.device_count() > 1: args.logger.info("{} GPUs are available. Let's use them.".format(torch.cuda.device_count())) model = torch.nn.DataParallel(model) diff --git a/finetune/run_cmrc.py b/finetune/run_cmrc.py index 696b38a..abdca41 100644 --- a/finetune/run_cmrc.py +++ b/finetune/run_cmrc.py @@ -159,11 +159,7 @@ def train(args, model, optimizer, scheduler, src_batch, seg_batch, start_positio if torch.cuda.device_count() > 1: loss = torch.mean(loss) - if args.fp16: - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - else: - loss.backward() + loss.backward() optimizer.step() scheduler.step() @@ -394,13 +390,6 @@ def main(): optimizer, scheduler = build_optimizer(args, model) - if args.fp16: - try: - from apex import amp - except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") - model, optimizer = amp.initialize(model, optimizer,opt_level=args.fp16_opt_level) - if torch.cuda.device_count() > 1: args.logger.info("{} GPUs are available. 
Let's use them.".format(torch.cuda.device_count())) model = torch.nn.DataParallel(model) diff --git a/finetune/run_dbqa.py b/finetune/run_dbqa.py index 41aec91..4b258b2 100644 --- a/finetune/run_dbqa.py +++ b/finetune/run_dbqa.py @@ -179,14 +179,6 @@ def main(): optimizer, scheduler = build_optimizer(args, model) - if args.fp16: - try: - from apex import amp - except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") - model, optimizer = amp.initialize(model, optimizer,opt_level = args.fp16_opt_level) - args.amp = amp - if torch.cuda.device_count() > 1: args.logger.info("{} GPUs are available. Let's use them.".format(torch.cuda.device_count())) model = torch.nn.DataParallel(model) diff --git a/finetune/run_image_classifier.py b/finetune/run_image_classifier.py index 9bd0d25..7e26c64 100644 --- a/finetune/run_image_classifier.py +++ b/finetune/run_image_classifier.py @@ -149,14 +149,6 @@ def main(): args.logger.info("The number of training instances: {}".format(instances_num)) optimizer, scheduler = build_optimizer(args, model) - if args.fp16: - try: - from apex import amp - except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") - model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) - args.amp = amp - if torch.cuda.device_count() > 1: args.logger.info("{} GPUs are available. Let's use them.".format(torch.cuda.device_count())) model = torch.nn.DataParallel(model) diff --git a/finetune/run_ner.py b/finetune/run_ner.py index eec05ad..935fceb 100644 --- a/finetune/run_ner.py +++ b/finetune/run_ner.py @@ -145,11 +145,7 @@ def train(args, model, optimizer, scheduler, src_batch, tgt_batch, seg_batch): if torch.cuda.device_count() > 1: loss = torch.mean(loss) - if args.fp16: - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - else: - loss.backward() + loss.backward() optimizer.step() scheduler.step() @@ -288,13 +284,6 @@ def main(): optimizer, scheduler = build_optimizer(args, model) - if args.fp16: - try: - from apex import amp - except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") - model, optimizer = amp.initialize(model, optimizer, opt_level = args.fp16_opt_level) - if torch.cuda.device_count() > 1: args.logger.info("{} GPUs are available. Let's use them.".format(torch.cuda.device_count())) model = torch.nn.DataParallel(model) diff --git a/finetune/run_regression.py b/finetune/run_regression.py index 21b96eb..264c783 100644 --- a/finetune/run_regression.py +++ b/finetune/run_regression.py @@ -147,14 +147,6 @@ def main(): args.logger.info("The number of training instances: {}".format(instances_num)) optimizer, scheduler = build_optimizer(args, model) - if args.fp16: - try: - from apex import amp - except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") - model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) - args.amp = amp - if torch.cuda.device_count() > 1: args.logger.info("{} GPUs are available. 
Let's use them.".format(torch.cuda.device_count())) model = torch.nn.DataParallel(model) diff --git a/finetune/run_simcse.py b/finetune/run_simcse.py index 90ca563..fa6e534 100644 --- a/finetune/run_simcse.py +++ b/finetune/run_simcse.py @@ -202,14 +202,6 @@ def main(): optimizer, scheduler = build_optimizer(args, model) - if args.fp16: - try: - from apex import amp - except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") - model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) - args.amp = amp - if torch.cuda.device_count() > 1: args.logger.info("{} GPUs are available. Let's use them.".format(torch.cuda.device_count())) model = torch.nn.DataParallel(model) @@ -245,11 +237,7 @@ def main(): tgt_batch = torch.arange(similarity_matrix.size(0), device=similarity_matrix.device, dtype=torch.long) loss = nn.CrossEntropyLoss()(similarity_matrix, tgt_batch) - if args.fp16: - with args.amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - else: - loss.backward() + loss.backward() optimizer.step() scheduler.step() diff --git a/finetune/run_speech2text.py b/finetune/run_speech2text.py index c031fcd..fb95615 100755 --- a/finetune/run_speech2text.py +++ b/finetune/run_speech2text.py @@ -149,11 +149,7 @@ def train_model(args, model, optimizer, scheduler, src_batch, tgt_in_batch, tgt_ if torch.cuda.device_count() > 1: loss = torch.mean(loss) - if args.fp16: - with args.amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - else: - loss.backward() + loss.backward() optimizer.step() scheduler.step() @@ -259,14 +255,6 @@ def main(): optimizer, scheduler = build_optimizer(args, model) - if args.fp16: - try: - from apex import amp - except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") - model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) - args.amp = amp - if torch.cuda.device_count() > 1: args.logger.info("{} GPUs are available. Let's use them.".format(torch.cuda.device_count())) model = torch.nn.DataParallel(model) diff --git a/finetune/run_text2text.py b/finetune/run_text2text.py index b2c759b..8712d6e 100755 --- a/finetune/run_text2text.py +++ b/finetune/run_text2text.py @@ -141,11 +141,7 @@ def train_model(args, model, optimizer, scheduler, src_batch, tgt_in_batch, tgt_ if torch.cuda.device_count() > 1: loss = torch.mean(loss) - if args.fp16: - with args.amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - else: - loss.backward() + loss.backward() optimizer.step() scheduler.step() @@ -262,14 +258,6 @@ def main(): optimizer, scheduler = build_optimizer(args, model) - if args.fp16: - try: - from apex import amp - except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") - model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) - args.amp = amp - if torch.cuda.device_count() > 1: args.logger.info("{} GPUs are available. 
Let's use them.".format(torch.cuda.device_count())) model = torch.nn.DataParallel(model) diff --git a/models/bloom/175b_config.json b/models/bloom/175b_config.json new file mode 100644 index 0000000..7815fd9 --- /dev/null +++ b/models/bloom/175b_config.json @@ -0,0 +1,21 @@ +{ + "emb_size": 14336, + "feedforward_size": 57344, + "hidden_size": 14336, + "hidden_act": "gelu", + "heads_num": 112, + "layers_num": 70, + "dropout": 0.0, + "data_processor": "lm", + "embedding": ["word"], + "remove_transformer_bias": false, + "has_lmtarget_bias": false, + "remove_embedding_layernorm": true, + "encoder": "transformer", + "mask": "causal", + "layernorm_positioning": "pre", + "target": ["lm"], + "tie_weights": true, + "alibi_position_embedding": true, + "layer_number_scale": true +} \ No newline at end of file diff --git a/models/bloom/1b1_config.json b/models/bloom/1b1_config.json new file mode 100644 index 0000000..6c22799 --- /dev/null +++ b/models/bloom/1b1_config.json @@ -0,0 +1,21 @@ +{ + "emb_size": 1536, + "feedforward_size": 6144, + "hidden_size": 1536, + "hidden_act": "gelu", + "heads_num": 16, + "layers_num": 24, + "dropout": 0.0, + "data_processor": "lm", + "embedding": ["word"], + "remove_transformer_bias": false, + "has_lmtarget_bias": false, + "remove_embedding_layernorm": true, + "encoder": "transformer", + "mask": "causal", + "layernorm_positioning": "pre", + "target": ["lm"], + "tie_weights": true, + "alibi_position_embedding": true, + "layer_number_scale": true +} \ No newline at end of file diff --git a/models/bloom/7b1_config.json b/models/bloom/7b1_config.json new file mode 100644 index 0000000..eaf392f --- /dev/null +++ b/models/bloom/7b1_config.json @@ -0,0 +1,21 @@ +{ + "emb_size": 4096, + "feedforward_size": 16384, + "hidden_size": 4096, + "hidden_act": "gelu", + "heads_num": 32, + "layers_num": 30, + "dropout": 0.0, + "data_processor": "lm", + "embedding": ["word"], + "remove_transformer_bias": false, + "has_lmtarget_bias": false, + "remove_embedding_layernorm": true, + "encoder": "transformer", + "mask": "causal", + "layernorm_positioning": "pre", + "target": ["lm"], + "tie_weights": true, + "alibi_position_embedding": true, + "layer_number_scale": true +} \ No newline at end of file diff --git a/models/llama2/13b_config.json b/models/llama2/13b_config.json new file mode 100644 index 0000000..7ec32fd --- /dev/null +++ b/models/llama2/13b_config.json @@ -0,0 +1,21 @@ +{ + "emb_size": 5120, + "feedforward_size": 13824, + "hidden_size": 5120, + "hidden_act": "silu", + "heads_num": 40, + "layers_num": 40, + "dropout": 0.0, + "data_processor": "lm", + "max_seq_length": 2048, + "embedding": ["word"], + "remove_transformer_bias": true, + "remove_embedding_layernorm": true, + "rotary_position_embedding": true, + "encoder": "transformer", + "feed_forward": "gated", + "mask": "causal", + "layernorm_positioning": "pre", + "layernorm": "rms", + "target": ["lm"] +} \ No newline at end of file diff --git a/models/llama2/70b_config.json b/models/llama2/70b_config.json new file mode 100644 index 0000000..a750241 --- /dev/null +++ b/models/llama2/70b_config.json @@ -0,0 +1,22 @@ +{ + "emb_size": 8192, + "feedforward_size": 28672, + "hidden_size": 8192, + "hidden_act": "silu", + "heads_num": 64, + "local_kv_heads_num": 8, + "layers_num": 80, + "dropout": 0.0, + "data_processor": "lm", + "max_seq_length": 2048, + "embedding": ["word"], + "remove_transformer_bias": true, + "remove_embedding_layernorm": true, + "rotary_position_embedding": true, + "encoder": "transformer", + "feed_forward": 
"gated", + "mask": "causal", + "layernorm_positioning": "pre", + "layernorm": "rms", + "target": ["lm"] +} \ No newline at end of file diff --git a/models/llama2/7b_config.json b/models/llama2/7b_config.json new file mode 100644 index 0000000..a5c282c --- /dev/null +++ b/models/llama2/7b_config.json @@ -0,0 +1,21 @@ +{ + "emb_size": 4096, + "feedforward_size": 11008, + "hidden_size": 4096, + "hidden_act": "silu", + "heads_num": 32, + "layers_num": 32, + "dropout": 0.0, + "data_processor": "lm", + "max_seq_length": 2048, + "embedding": ["word"], + "remove_transformer_bias": true, + "remove_embedding_layernorm": true, + "rotary_position_embedding": true, + "encoder": "transformer", + "feed_forward": "gated", + "mask": "causal", + "layernorm_positioning": "pre", + "layernorm": "rms", + "target": ["lm"] +} \ No newline at end of file diff --git a/tencentpretrain/encoders/transformer_encoder.py b/tencentpretrain/encoders/transformer_encoder.py index 68673be..e848d12 100644 --- a/tencentpretrain/encoders/transformer_encoder.py +++ b/tencentpretrain/encoders/transformer_encoder.py @@ -2,8 +2,8 @@ import torch.nn as nn from tencentpretrain.utils.rope import precompute_freqs_cis from tencentpretrain.layers.transformer import TransformerLayer -from tencentpretrain.layers.layer_norm import * from tencentpretrain.layers.relative_position_embedding import RelativePositionEmbedding +from tencentpretrain.layers import * class TransformerEncoder(nn.Module): """ @@ -37,12 +37,7 @@ def __init__(self, args): [TransformerLayer(args) for _ in range(self.layers_num)] ) if self.layernorm_positioning == "pre": - if args.layernorm == "t5": - self.layer_norm = T5LayerNorm(args.hidden_size) - elif args.layernorm == "rms": - self.layer_norm = RMSNorm(args.hidden_size) - else: - self.layer_norm = LayerNorm(args.hidden_size) + self.layer_norm = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps) if self.relative_position_embedding: self.relative_pos_emb = RelativePositionEmbedding(bidirectional=True, heads_num=args.heads_num, diff --git a/tencentpretrain/layers/__init__.py b/tencentpretrain/layers/__init__.py index e69de29..a0d01fc 100644 --- a/tencentpretrain/layers/__init__.py +++ b/tencentpretrain/layers/__init__.py @@ -0,0 +1,12 @@ +from tencentpretrain.layers.layer_norm import * +from tencentpretrain.layers.multi_headed_attn import * +from tencentpretrain.layers.position_ffn import * +import torch.nn as nn + + +str2layernorm = {"t5": T5LayerNorm, "rms": RMSNorm, "normal": LayerNorm} + +str2feedforward = {"gated": GatedFeedForward, "dense": PositionwiseFeedForward} + +__all__ = ["T5LayerNorm", "RMSNorm", "LayerNorm", "GatedFeedForward", "PositionwiseFeedForward", + "str2layernorm", "str2feedforward"] \ No newline at end of file diff --git a/tencentpretrain/layers/layer_norm.py b/tencentpretrain/layers/layer_norm.py index fd38795..b16f936 100644 --- a/tencentpretrain/layers/layer_norm.py +++ b/tencentpretrain/layers/layer_norm.py @@ -7,16 +7,20 @@ class LayerNorm(nn.Module): Layer Normalization. 
https://arxiv.org/abs/1607.06450 """ - def __init__(self, hidden_size, eps=1e-6): + def __init__(self, hidden_size, eps=1e-6, eps_inside=False): super(LayerNorm, self).__init__() self.eps = eps + self.eps_inside = eps_inside self.gamma = nn.Parameter(torch.ones(hidden_size)) self.beta = nn.Parameter(torch.zeros(hidden_size)) def forward(self, x): mean = x.mean(-1, keepdim=True) - std = x.std(-1, keepdim=True) - hidden_states = self.gamma * (x-mean) / (std + self.eps) + if self.eps_inside: + std = torch.sqrt(x.var(-1, keepdim=True) + self.eps) + else: + std = x.std(-1, keepdim=True) + self.eps + hidden_states = self.gamma * (x-mean) / std return hidden_states + self.beta diff --git a/tencentpretrain/layers/multi_headed_attn.py b/tencentpretrain/layers/multi_headed_attn.py index e5fcd12..128ff66 100755 --- a/tencentpretrain/layers/multi_headed_attn.py +++ b/tencentpretrain/layers/multi_headed_attn.py @@ -4,20 +4,41 @@ from tencentpretrain.utils.rope import apply_rotary_emb from tencentpretrain.utils.lora import LoraLinear + +def repeat_kv(x: torch.Tensor, repeat_num: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, seq_length, kv_heads_num, head_dim = x.shape + if repeat_num == 1: + return x + + else: + return ( + x[:, :, :, None, :] + .expand(bs, seq_length, kv_heads_num, repeat_num, head_dim) + .reshape(bs, seq_length, kv_heads_num * repeat_num, head_dim) + ) + + class MultiHeadedAttention(nn.Module): """ Each head is a self-attention operation. self-attention refers to https://arxiv.org/pdf/1706.03762.pdf """ - def __init__(self, hidden_size, heads_num, attention_head_size, dropout, has_bias=True, with_scale=True, - lora_params=None): + def __init__(self, hidden_size, heads_num, attention_head_size, local_kv_heads_num, dropout, has_bias=True, with_scale=True, + lora_params=None, layer_number=None): super(MultiHeadedAttention, self).__init__() self.heads_num = heads_num - self.per_head_size = attention_head_size self.with_scale = with_scale self.inner_hidden_size = heads_num * attention_head_size + self.local_kv_heads_num = local_kv_heads_num + + self.kv_embed_dim = self.inner_hidden_size // heads_num * self.local_kv_heads_num + self.num_head_groups = heads_num // self.local_kv_heads_num + assert heads_num >= self.local_kv_heads_num, "heads_num should be greater than or equal to n_local_kv_heads" + assert heads_num % self.local_kv_heads_num == 0, "heads_num should be divisible by n_local_kv_heads" + self.repeat_num = self.heads_num // self.local_kv_heads_num if lora_params is not None: @@ -32,13 +53,19 @@ def __init__(self, hidden_size, heads_num, attention_head_size, dropout, has_bia ) else: self.linear_layers = nn.ModuleList( - [nn.Linear(hidden_size, self.inner_hidden_size, bias=has_bias) for _ in range(3)] + [nn.Linear(hidden_size, self.inner_hidden_size, bias=has_bias) if i==0 else nn.Linear(hidden_size, self.kv_embed_dim, bias=has_bias) for i in range(3)] ) self.dropout = nn.Dropout(dropout) self.final_linear = nn.Linear(self.inner_hidden_size, hidden_size, bias=has_bias) + # layer-wise attention scaling + if layer_number is not None: + self.layer_number = max(1, layer_number) + self.norm_factor = math.sqrt(self.per_head_size) * self.layer_number + else: + self.layer_number = None def forward(self, key, value, query, mask, position_bias=None, has_residual_attention=False, prev_attn=None, - freqs_cis=None): + freqs_cis=None, alibi=None): """ Args: key: [batch_size x seq_length x hidden_size] @@ -61,28 +88,53 @@ def shape(x): def unshape(x): return 
x. \ - transpose(1, 2). \ - contiguous(). \ - view(batch_size, seq_length, self.inner_hidden_size) + transpose(1, 2). \ + contiguous(). \ + view(batch_size, seq_length, self.inner_hidden_size) + + query, key, value = [linear_layer(x) for linear_layer, x in zip(self.linear_layers, [query, key, value])] + + query = query.view(batch_size, seq_length, heads_num, per_head_size) + key = key.view(batch_size, seq_length, self.local_kv_heads_num, per_head_size) + value = value.view(batch_size, seq_length, self.local_kv_heads_num, per_head_size) + + query = query.transpose(1, 2) + key = repeat_kv(key, self.repeat_num).transpose(1, 2) + value = repeat_kv(value, self.repeat_num).transpose(1, 2) + - query, key, value = [l(x). \ - view(batch_size, -1, heads_num, per_head_size). \ - transpose(1, 2) \ - for l, x in zip(self.linear_layers, (query, key, value)) - ] if freqs_cis is not None: query, key = apply_rotary_emb(query.transpose(1,2), key.transpose(1,2), freqs_cis=freqs_cis) + scores = torch.matmul(query, key.transpose(-2, -1)) + if position_bias is not None: scores = scores + position_bias + if self.with_scale: - scores = scores / math.sqrt(float(per_head_size)) + if self.layer_number is not None: + scores = scores * (1.0 / self.norm_factor) + else: + scores = scores / math.sqrt(float(per_head_size)) + if alibi is not None: + scores = scores.reshape((-1, scores.shape[-2], scores.shape[-1])) + scores += (1.0 / self.layer_number) * alibi + scores = scores.view(-1, heads_num, scores.shape[-2], scores.shape[-1]) + scores = scores + mask.type_as(scores) + + # scaled softmax + if self.layer_number is not None: + scores = (scores * self.layer_number) + mask + scores = torch.max(scores, torch.tensor(-10000)) + prev_attn_out = None + if has_residual_attention: if prev_attn is not None: scores += prev_attn prev_attn_out = scores + probs = nn.Softmax(dim=-1)(scores) probs = self.dropout(probs) output = unshape(torch.matmul(probs, value)) diff --git a/tencentpretrain/layers/transformer.py b/tencentpretrain/layers/transformer.py index f859ec0..be9f2cf 100755 --- a/tencentpretrain/layers/transformer.py +++ b/tencentpretrain/layers/transformer.py @@ -1,16 +1,13 @@ import torch.nn as nn -from tencentpretrain.layers.layer_norm import * -from tencentpretrain.layers.position_ffn import PositionwiseFeedForward, GatedFeedForward from tencentpretrain.layers.multi_headed_attn import MultiHeadedAttention -from tencentpretrain.layers.relative_position_embedding import RelativePositionEmbedding - +from tencentpretrain.layers import * class TransformerLayer(nn.Module): """ Transformer layer mainly consists of two parts: multi-headed self-attention and feed forward layer. 
""" - def __init__(self, args): + def __init__(self, args, layer_number=None): super(TransformerLayer, self).__init__() self.layernorm_positioning = args.layernorm_positioning @@ -20,6 +17,11 @@ def __init__(self, args): else: attention_head_size = args.hidden_size // args.heads_num + if hasattr(args, "local_kv_heads_num"): + local_kv_heads_num = args.local_kv_heads_num + else: + local_kv_heads_num = args.heads_num + has_bias = bool(1 - args.remove_transformer_bias) with_scale = bool(1 - args.remove_attention_scale) @@ -29,33 +31,22 @@ def __init__(self, args): lora_params = args.lora_params self.self_attn = MultiHeadedAttention( - args.hidden_size, args.heads_num, attention_head_size, args.dropout, has_bias=has_bias, - with_scale = with_scale, lora_params=lora_params + args.hidden_size, args.heads_num, attention_head_size, local_kv_heads_num, args.dropout, has_bias=has_bias, + with_scale = with_scale, lora_params=lora_params, layer_number=layer_number ) self.dropout_1 = nn.Dropout(args.dropout) # Feed forward layer. - if args.feed_forward == "gated": - self.feed_forward = GatedFeedForward( - args.hidden_size, args.feedforward_size, args.hidden_act, has_bias - ) - else: - self.feed_forward = PositionwiseFeedForward( - args.hidden_size, args.feedforward_size, args.hidden_act, has_bias - ) + self.feed_forward = str2feedforward[args.feed_forward]( + args.hidden_size, args.feedforward_size, args.hidden_act, has_bias + ) self.dropout_2 = nn.Dropout(args.dropout) - if args.layernorm == "t5": - self.layer_norm_1 = T5LayerNorm(args.hidden_size) - self.layer_norm_2 = T5LayerNorm(args.hidden_size) - elif args.layernorm == "rms": - self.layer_norm_1 = RMSNorm(args.hidden_size) - self.layer_norm_2 = RMSNorm(args.hidden_size) - else: - self.layer_norm_1 = LayerNorm(args.hidden_size) - self.layer_norm_2 = LayerNorm(args.hidden_size) + self.layer_norm_1 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps) + self.layer_norm_2 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps) - def forward(self, hidden, mask, position_bias=None, has_residual_attention=False, prev_attn=None, freqs_cis=None): + def forward(self, hidden, mask, position_bias=None, has_residual_attention=False, + prev_attn=None, freqs_cis=None, alibi=None): """ Args: hidden: [batch_size x seq_length x emb_size] @@ -66,14 +57,16 @@ def forward(self, hidden, mask, position_bias=None, has_residual_attention=False """ if self.layernorm_positioning == "post": - inter, prev_attn_out = self.self_attn(hidden, hidden, hidden, mask, position_bias, has_residual_attention, prev_attn, freqs_cis) + inter, prev_attn_out = self.self_attn(hidden, hidden, hidden, mask, position_bias, has_residual_attention, + prev_attn, freqs_cis, alibi) inter = self.dropout_1(inter) inter = self.layer_norm_1(inter + hidden) output = self.dropout_2(self.feed_forward(inter)) output = self.layer_norm_2(output + inter) else: inter = self.layer_norm_1(hidden) - inter, prev_attn_out = self.self_attn(inter, inter, inter, mask, position_bias, has_residual_attention, prev_attn, freqs_cis) + inter, prev_attn_out = self.self_attn(inter, inter, inter, mask, position_bias, has_residual_attention, + prev_attn, freqs_cis, alibi) inter = self.dropout_1(inter) hidden = hidden + inter output = self.layer_norm_2(hidden) @@ -92,6 +85,11 @@ def __init__(self, args): else: attention_head_size = args.hidden_size // args.heads_num + if hasattr(args, "local_kv_heads_num"): + local_kv_heads_num = args.local_kv_heads_num + else: + local_kv_heads_num = 
args.heads_num + has_bias = bool(1 - args.remove_transformer_bias) with_scale = bool(1 - args.remove_attention_scale) @@ -101,38 +99,28 @@ def __init__(self, args): lora_params = args.lora_params self.self_attn = MultiHeadedAttention( - args.hidden_size, args.heads_num, attention_head_size, args.dropout, has_bias=has_bias, + args.hidden_size, args.heads_num, attention_head_size, local_kv_heads_num, args.dropout, has_bias=has_bias, with_scale=with_scale, lora_params=lora_params ) self.dropout_1 = nn.Dropout(args.dropout) # Multi-headed context-attention. self.context_attn = MultiHeadedAttention( - args.hidden_size, args.heads_num, attention_head_size, args.dropout, has_bias=has_bias, + args.hidden_size, args.heads_num, attention_head_size, local_kv_heads_num, args.dropout, has_bias=has_bias, with_scale=with_scale, lora_params=lora_params ) self.dropout_2 = nn.Dropout(args.dropout) # Feed forward layer. - if args.feed_forward == "gated": - self.feed_forward = GatedFeedForward( - args.hidden_size, args.feedforward_size, args.hidden_act, has_bias - ) - else: - self.feed_forward = PositionwiseFeedForward( - args.hidden_size, args.feedforward_size, args.hidden_act, has_bias - ) + self.feed_forward = str2feedforward[args.feed_forward]( + args.hidden_size, args.feedforward_size, args.hidden_act, has_bias + ) self.dropout_3 = nn.Dropout(args.dropout) # Layer Normalization - if args.layernorm == "t5": - self.layer_norm_1 = T5LayerNorm(args.hidden_size) - self.layer_norm_2 = T5LayerNorm(args.hidden_size) - self.layer_norm_3 = T5LayerNorm(args.hidden_size) - else: - self.layer_norm_1 = LayerNorm(args.hidden_size) - self.layer_norm_2 = LayerNorm(args.hidden_size) - self.layer_norm_3 = LayerNorm(args.hidden_size) + self.layer_norm_1 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps) + self.layer_norm_2 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps) + self.layer_norm_3 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps) def forward(self, hidden, encoder_hidden, mask_decoder, mask_encoder, self_position_bias=None, context_position_bias=None): """ diff --git a/tencentpretrain/opts.py b/tencentpretrain/opts.py index c4b8074..eec3b73 100755 --- a/tencentpretrain/opts.py +++ b/tencentpretrain/opts.py @@ -30,8 +30,11 @@ def model_opts(parser): help="Remove attention scale.") parser.add_argument("--remove_transformer_bias", action="store_true", help="Remove bias on transformer layers.") - parser.add_argument("--layernorm", choices=["normal", "t5","rms"], default="normal", + parser.add_argument("--layernorm", choices=["normal", "t5", "rms"], default="normal", + help="Layernorm type.") + parser.add_argument("--layernorm_eps", type=float, default=1e-6, + help="Layernorm eps.") parser.add_argument("--bidirectional", action="store_true", help="Specific to recurrent model.") parser.add_argument("--parameter_sharing", action="store_true", help="Parameter sharing.") parser.add_argument("--has_residual_attention", action="store_true", help="Add residual attention.") @@ -45,6 +48,10 @@ def model_opts(parser): help="Pooling type.") parser.add_argument("--prefix_lm_loss", action="store_true", help="Only compute output loss when SFT.") + parser.add_argument("--alibi_position_embedding", action="store_true", + help="whether use alibi position embedding.") + parser.add_argument("--layer_number_scale", action="store_true", + help="whether use layer number scaling.") vision_opts(parser) audio_opts(parser) @@ -97,11 +104,6 @@ def optimization_opts(parser): 
help="Warm up value.") parser.add_argument("--decay", type=float, default=0.5, help="decay value.") - parser.add_argument("--fp16", action='store_true', - help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit.") - parser.add_argument("--fp16_opt_level", choices=["O0", "O1", "O2", "O3" ], default='O1', - help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." - "See details at https://nvidia.github.io/apex/amp.html") parser.add_argument("--optimizer", choices=["adamw", "adafactor"], default="adamw", help="Optimizer type.") @@ -173,7 +175,8 @@ def infer_opts(parser): def tokenizer_opts(parser): - parser.add_argument("--tokenizer", choices=["bert", "bpe", "char", "space", "xlmroberta", "image", "text_image", "virtual"], default="bert", + parser.add_argument("--tokenizer", choices=["bert", "bpe", "char", "space", "xlmroberta", "image", "text_image", + "virtual", "hfpretrained"], default="bert", help="Specify the tokenizer." "Original Google BERT uses bert tokenizer." "Char tokenizer segments sentences into characters." @@ -218,6 +221,8 @@ def deepspeed_opts(parser): help="Checkpoint activation to allow for training with larger models, sequences, and batch sizes.") parser.add_argument("--deepspeed_checkpoint_layers_num", type=int, default=1, help="chunk size (number of layers) for checkpointing.") + parser.add_argument("--resume_from_checkpoint", type=str, default=None, + help="resume form deepspeed format checkpoint (only support zero1&2).") parser.add_argument("--local_rank", type=int, required=False) diff --git a/tencentpretrain/trainer.py b/tencentpretrain/trainer.py index 298eb79..a959b3b 100755 --- a/tencentpretrain/trainer.py +++ b/tencentpretrain/trainer.py @@ -33,7 +33,7 @@ def train_and_validate(args): model_for_training = build_model(args) # Load or initialize parameters. - if args.pretrained_model_path is not None: + if args.pretrained_model_path is not None and args.resume_from_checkpoint is None: # Initialize with pretrained model. 
if args.deepspeed and args.enable_zero3: if os.path.isdir(args.pretrained_model_path): @@ -132,11 +132,7 @@ def train(self, args, gpu_id, rank, loader, model, optimizer, scheduler): if args.deepspeed: model.backward(loss) else: - if args.fp16: - with args.amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - else: - loss.backward() + loss.backward() if self.current_step % self.accumulation_steps == 0: if args.deepspeed: @@ -607,7 +603,7 @@ def worker(proc_id, gpu_ranks, args, model_for_training, model_for_dataloader=No ] if args.optimizer in ["adamw"]: - if args.deepspeed and deepspeed.__version__ > "0.5.8": + if args.deepspeed: custom_optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam(optimizer_grouped_parameters, lr=args.learning_rate, bias_correction=False) else: custom_optimizer = str2optimizer[args.optimizer](optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False) @@ -631,6 +627,12 @@ def worker(proc_id, gpu_ranks, args, model_for_training, model_for_dataloader=No lr_scheduler=custom_scheduler, mpu=None, dist_init_required=False) + if args.resume_from_checkpoint is not None: + load_path, _ = model_for_training.load_checkpoint( + args.resume_from_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True + ) + if load_path is None: + raise ValueError(f"[deepspeed] failed to resume from checkpoint {args.resume_from_checkpoint}") else: if gpu_id is not None: model_for_training.cuda(gpu_id) @@ -638,13 +640,6 @@ def worker(proc_id, gpu_ranks, args, model_for_training, model_for_dataloader=No model_for_dataloader.cuda(gpu_id) optimizer = custom_optimizer scheduler = custom_scheduler - if args.fp16: - try: - from apex import amp - except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") - model_for_training, optimizer = amp.initialize(model_for_training, optimizer, opt_level=args.fp16_opt_level) - args.amp = amp if args.dist_train: # Initialize multiprocessing distributed training environment. diff --git a/tencentpretrain/utils/__init__.py b/tencentpretrain/utils/__init__.py index e7a0e3c..aa29517 100644 --- a/tencentpretrain/utils/__init__.py +++ b/tencentpretrain/utils/__init__.py @@ -7,7 +7,7 @@ str2tokenizer = {"char": CharTokenizer, "space": SpaceTokenizer, "bert": BertTokenizer, "bpe": BPETokenizer, "xlmroberta": XLMRobertaTokenizer, "image": ImageTokenizer, - "text_image": TextImageTokenizer, "virtual": VirtualTokenizer} + "text_image": TextImageTokenizer, "virtual": VirtualTokenizer, "hfpretrained": HFPreTrainedTokenizer} str2dataset = {"bert": BertDataset, "lm": LmDataset, "mlm": MlmDataset, "bilm": BilmDataset, "albert": AlbertDataset, "mt": MtDataset, "t5": T5Dataset, "gsg": GsgDataset, "bart": BartDataset, diff --git a/tencentpretrain/utils/alibi.py b/tencentpretrain/utils/alibi.py new file mode 100644 index 0000000..842c8d4 --- /dev/null +++ b/tencentpretrain/utils/alibi.py @@ -0,0 +1,46 @@ +import torch +import math + + +def build_alibi_tensor(attention_mask: torch.Tensor, n_head: int, dtype, device) -> torch.Tensor: + """ + Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it + relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value + `softmax(l+a) = softmax(l)`. 
Based on + https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 + Args: + Returns tensor shaped (batch_size * n_head, 1, max_seq_len) + attention_mask (`torch.Tensor`): + Token-wise attention mask, this should be of shape (batch_size, max_seq_len). + n_head (`int`, *required*): + number of heads + dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`): + dtype of the output tensor + device (`torch.device`, *optional*, default=`torch.device('cpu')`): + device of the output alibi tensor + """ + closest_power_of_2 = 2 ** math.floor(math.log2(n_head)) + base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32) + powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != n_head: + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32 + ) + num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2) + extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) + + # Note: alibi will added to the attention bias that will be applied to the query, key product of attention + # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) + # => here we set (batch_size=1, num_heads=n_head, query_length=1, key_length=max_length) + # => the query_length dimension will then be broadcasted correctly + # This is more or less identical to T5's relative position bias: + # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 + # batch_size = 1, n_head = n_head, query_length + + arange_tensor = (attention_mask.cumsum(-1)[:, None, :].to(device) - 1) * attention_mask[:, None] + alibi = slopes.unsqueeze(-1) * arange_tensor + alibi = alibi * attention_mask[:, None] + return alibi.reshape(alibi.shape[0] * n_head, 1, -1).to(dtype) \ No newline at end of file diff --git a/tencentpretrain/utils/tokenizers.py b/tencentpretrain/utils/tokenizers.py index 2f13a3d..6ee3512 100644 --- a/tencentpretrain/utils/tokenizers.py +++ b/tencentpretrain/utils/tokenizers.py @@ -602,3 +602,23 @@ def __init__(self, args, is_src=True): self.vocab_bias = len(self.vocab) for i in range(args.image_tokenizer["image_vocab_size"]): self.vocab[i + self.vocab_bias] = str(i) + + +class HFPreTrainedTokenizer(Tokenizer): + def __init__(self, args, is_src=True): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(args.vocab_path) + self.sp_model = None + self.vocab = self.tokenizer.vocab + + def tokenize(self, text): + return self.tokenizer.tokenize(text) + + def convert_tokens_to_ids(self, tokens): + return self.tokenizer.convert_tokens_to_ids(tokens) + + def convert_ids_to_tokens(self, ids): + return self.tokenizer.convert_ids_to_tokens(ids) + + def decode(self, ids): + return self.tokenizer.decode(ids)
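
Implementation notes on the changes above, with small runnable sketches. The grouped-query attention path added to MultiHeadedAttention projects keys and values down to local_kv_heads_num heads and then repeats them to match the query heads, which is how Llama2-70B's 8 KV heads serve 64 query heads. A minimal shape sketch with toy dimensions (not the real model sizes); repeat_kv is reproduced from the diff above:

import torch

def repeat_kv(x: torch.Tensor, repeat_num: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, seq_length, kv_heads_num, head_dim = x.shape
    if repeat_num == 1:
        return x
    else:
        return (
            x[:, :, :, None, :]
            .expand(bs, seq_length, kv_heads_num, repeat_num, head_dim)
            .reshape(bs, seq_length, kv_heads_num * repeat_num, head_dim)
        )

batch, seq, heads_num, local_kv_heads_num, head_dim = 2, 16, 8, 2, 64
repeat_num = heads_num // local_kv_heads_num        # 4 query heads share each KV head

q = torch.randn(batch, seq, heads_num, head_dim)
k = torch.randn(batch, seq, local_kv_heads_num, head_dim)
v = torch.randn(batch, seq, local_kv_heads_num, head_dim)

k_rep, v_rep = repeat_kv(k, repeat_num), repeat_kv(v, repeat_num)
assert k_rep.shape == v_rep.shape == q.shape        # (2, 16, 8, 64)
# The expand/reshape trick matches the reference op named in the docstring.
assert torch.equal(k_rep, torch.repeat_interleave(k, repeats=repeat_num, dim=2))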
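
build_alibi_tensor in tencentpretrain/utils/alibi.py follows the BLOOM recipe: each head gets a slope from a geometric series and the bias grows linearly with key position, and MultiHeadedAttention adds it to the raw scores (scaled by 1/layer_number when layer-number scaling is enabled). A quick sanity check of the output shape and the head-0 slope, assuming the helper exactly as added above:

import torch
from tencentpretrain.utils.alibi import build_alibi_tensor

batch_size, seq_length, n_head = 2, 5, 8
attention_mask = torch.ones(batch_size, seq_length)

alibi = build_alibi_tensor(attention_mask, n_head,
                           dtype=torch.float32, device=torch.device("cpu"))
# Flattened over heads, broadcastable against (batch * heads, query_len, key_len) scores.
assert alibi.shape == (batch_size * n_head, 1, seq_length)

# For a power-of-two head count the slopes are 2^-1 ... 2^-n_head, so head 0 of the
# first batch contributes 0.5 * key_position.
expected_head0 = 0.5 * torch.arange(seq_length, dtype=torch.float32)
assert torch.allclose(alibi[0, 0], expected_head0)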
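
The new eps_inside flag on LayerNorm changes the normalizer from std + eps to sqrt(var + eps), i.e. it only moves where eps enters; this is the placement BLOOM-style checkpoints expect. A short sketch of both modes, assuming the class as modified above:

import torch
from tencentpretrain.layers.layer_norm import LayerNorm

x = torch.randn(2, 4, 16)

ln_outside = LayerNorm(16, eps=1e-5)                   # previous behaviour: (x - mean) / (std + eps)
ln_inside = LayerNorm(16, eps=1e-5, eps_inside=True)   # new option: (x - mean) / sqrt(var + eps)

# The two normalizers differ only in eps placement; note that both still use
# torch's default unbiased variance/std, so neither is bit-identical to
# torch.nn.LayerNorm, which uses the biased estimator.
print((ln_outside(x) - ln_inside(x)).abs().max())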
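
The tokenizer registry gains an "hfpretrained" option backed by the new HFPreTrainedTokenizer, which defers to transformers.AutoTokenizer loaded from --vocab_path. A hedged usage sketch (the args stand-in and the tokenizer name are illustrative only; the transformers package must be installed):

from types import SimpleNamespace
from tencentpretrain.utils.tokenizers import HFPreTrainedTokenizer

# Stand-in for the parsed CLI args: with --tokenizer hfpretrained, vocab_path is
# whatever AutoTokenizer.from_pretrained accepts (a hub name or a local directory).
args = SimpleNamespace(vocab_path="bigscience/bloom-1b1")  # hypothetical choice for illustration

tokenizer = HFPreTrainedTokenizer(args)
tokens = tokenizer.tokenize("TencentPretrain now ships BLOOM and Llama2 configs.")
ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokenizer.decode(ids))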