diff --git a/.gitignore b/.gitignore
index 9b4fde65..2d0d03c5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -136,3 +136,6 @@ pic2.py
 # Pyre type checker
 .pyre/
+wandb/
+*.txt
+result/
\ No newline at end of file
diff --git a/FlagEmbedding/baai_general_embedding/finetune/data.py b/FlagEmbedding/baai_general_embedding/finetune/data.py
index 387c1482..57381e7f 100644
--- a/FlagEmbedding/baai_general_embedding/finetune/data.py
+++ b/FlagEmbedding/baai_general_embedding/finetune/data.py
@@ -8,7 +8,7 @@ from torch.utils.data import Dataset
 from transformers import DataCollatorWithPadding, PreTrainedTokenizer
 
-from .arguments import DataArguments
+from arguments import DataArguments
 
 
 class TrainDatasetForEmbedding(Dataset):
diff --git a/FlagEmbedding/baai_general_embedding/finetune/hn_mine.py b/FlagEmbedding/baai_general_embedding/finetune/hn_mine.py
index dfa82a75..8d9a5a94 100644
--- a/FlagEmbedding/baai_general_embedding/finetune/hn_mine.py
+++ b/FlagEmbedding/baai_general_embedding/finetune/hn_mine.py
@@ -59,6 +59,7 @@ def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, n
     corpus = []
     queries = []
     train_data = []
+    # input_file is jsonl: each line is a (query, pos, neg) triple; every pos and neg
+    # passage goes into corpus, and each query goes into queries
     for line in open(input_file):
         line = json.loads(line.strip())
         train_data.append(line)
@@ -67,6 +68,7 @@ def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, n
         corpus.extend(line['neg'])
         queries.append(line['query'])
 
+    # candidate_pool and the corpus built from input_file are mutually exclusive sources
    if candidate_pool is not None:
        if not isinstance(candidate_pool, list):
            candidate_pool = get_corpus(candidate_pool)
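For concreteness, one `input_file` line in the layout the new comment describes — a toy sketch that reuses values from the `examples/finetune/toy_finetune_data_minedHN.jsonl` file added later in this diff:

import json

# One hn_mine.py input line: a query with its positive and negative passages.
line = json.dumps({
    "query": "Five women walk along a beach wearing flip-flops.",
    "pos": ["Some women with flip-flops on, are walking along the beach"],
    "neg": ["The seal of Missouri is perfect."],
})
print(line)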
diff --git a/FlagEmbedding/baai_general_embedding/finetune/modeling.py b/FlagEmbedding/baai_general_embedding/finetune/modeling.py
index c5d2935f..70600e04 100644
--- a/FlagEmbedding/baai_general_embedding/finetune/modeling.py
+++ b/FlagEmbedding/baai_general_embedding/finetune/modeling.py
@@ -60,19 +60,21 @@ def __init__(self,
     def gradient_checkpointing_enable(self, **kwargs):
         self.model.gradient_checkpointing_enable(**kwargs)
 
-    def sentence_embedding(self, hidden_state, mask):
+    def sentence_embedding(self, output, mask):
         if self.sentence_pooling_method == 'mean':
-            s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
+            s = torch.sum(output.last_hidden_state * mask.unsqueeze(-1).float(), dim=1)
             d = mask.sum(axis=1, keepdim=True).float()
             return s / d
         elif self.sentence_pooling_method == 'cls':
-            return hidden_state[:, 0]
+            return output.last_hidden_state[:, 0]
+        elif self.sentence_pooling_method == 'cls_after_pooler':
+            return output.pooler_output
 
     def encode(self, features):
         if features is None:
             return None
         psg_out = self.model(**features, return_dict=True)
-        p_reps = self.sentence_embedding(psg_out.last_hidden_state, features['attention_mask'])
+        p_reps = self.sentence_embedding(psg_out, features['attention_mask'])
         if self.normlized:
             p_reps = torch.nn.functional.normalize(p_reps, dim=-1)
         return p_reps.contiguous()
diff --git a/FlagEmbedding/baai_general_embedding/finetune/run.py b/FlagEmbedding/baai_general_embedding/finetune/run.py
index ff48281d..311e0239 100644
--- a/FlagEmbedding/baai_general_embedding/finetune/run.py
+++ b/FlagEmbedding/baai_general_embedding/finetune/run.py
@@ -1,6 +1,7 @@
 import logging
 import os
 from pathlib import Path
+os.environ["WANDB_DISABLED"] = "true"
 
 from transformers import AutoConfig, AutoTokenizer
 from transformers import (
@@ -8,15 +9,26 @@
     set_seed,
 )
 
-from .arguments import ModelArguments, DataArguments, \
+from arguments import ModelArguments, DataArguments, \
     RetrieverTrainingArguments as TrainingArguments
-from .data import TrainDatasetForEmbedding, EmbedCollator
-from .modeling import BiEncoderModel
-from .trainer import BiTrainer
+from data import TrainDatasetForEmbedding, EmbedCollator
+from modeling import BiEncoderModel
+from trainer import BiTrainer
+import sys
+import transformers
+sys.path.append("/opt/tiger/FlagEmbedding")
+from utils import get_complete_last_checkpoint
 
 logger = logging.getLogger(__name__)
-
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
+    """Collect the state dict on CPU and dump it to disk."""
+    state_dict = trainer.model.state_dict()
+    if trainer.args.should_save:
+        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
+        del state_dict
+        trainer._save(output_dir, state_dict=cpu_state_dict)
+
 def main():
     parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
@@ -24,15 +36,20 @@ def main():
     data_args: DataArguments
     training_args: TrainingArguments
 
-    if (
-        os.path.exists(training_args.output_dir)
-        and os.listdir(training_args.output_dir)
-        and training_args.do_train
-        and not training_args.overwrite_output_dir
-    ):
-        raise ValueError(
-            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
-        )
+    # check and load checkpoint
+    last_checkpoint = None
+    if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
+        last_checkpoint = get_complete_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+            logger.info(
+                f"Output directory ({training_args.output_dir}) already exists but contains no complete checkpoint. "
+                "Training from scratch."
+            )
+        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+            logger.info(
+                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+            )
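`get_complete_last_checkpoint` is imported from a project-local `utils` module that this diff does not include. A minimal sketch of what such a helper plausibly looks like — the name is real, but the body, the `checkpoint-<step>` naming, and the completeness criterion are assumptions, not the repo's code:

import os
import re

def get_complete_last_checkpoint(output_dir: str):
    # Hypothetical reconstruction of the helper imported above.
    pattern = re.compile(r"^checkpoint-(\d+)$")
    candidates = []
    for name in os.listdir(output_dir):
        path = os.path.join(output_dir, name)
        match = pattern.match(name)
        # treat a checkpoint as "complete" once its trainer_state.json is written
        if match and os.path.isdir(path) and os.path.isfile(os.path.join(path, "trainer_state.json")):
            candidates.append((int(match.group(1)), path))
    return max(candidates)[1] if candidates else None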
 
     # Setup logging
     logging.basicConfig(
@@ -75,6 +92,14 @@ def main():
         temperature=training_args.temperature,
         use_inbatch_neg=training_args.use_inbatch_neg,
     )
+
+    checkpoint = None
+    if training_args.resume_from_checkpoint is not None:
+        logger.info(f"train start from {training_args.resume_from_checkpoint}")
+        checkpoint = training_args.resume_from_checkpoint
+    elif last_checkpoint is not None:
+        logger.info(f"train start from {last_checkpoint}")
+        checkpoint = last_checkpoint
 
     if training_args.fix_position_embedding:
         for k, v in model.named_parameters():
@@ -99,12 +124,18 @@ def main():
     Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)
 
     # Training
-    trainer.train()
-    trainer.save_model()
+    trainer.train(resume_from_checkpoint=checkpoint)
+    trainer.save_state()
+    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
+    # try:
+    #     trainer.train(resume_from_checkpoint=checkpoint)
+    # except:
+    #     trainer.train()
+    # trainer.save_model()
 
     # For convenience, we also re-save the tokenizer to the same directory,
     # so that you can share your model easily on huggingface.co/models =)
-    if trainer.is_world_process_zero():
-        tokenizer.save_pretrained(training_args.output_dir)
+    # if trainer.is_world_process_zero():
+    #     tokenizer.save_pretrained(training_args.output_dir)
 
 
 if __name__ == "__main__":
diff --git a/FlagEmbedding/reranker/arguments.py b/FlagEmbedding/reranker/arguments.py
index ae5d639a..f89554fe 100644
--- a/FlagEmbedding/reranker/arguments.py
+++ b/FlagEmbedding/reranker/arguments.py
@@ -21,6 +21,7 @@ class ModelArguments:
     cache_dir: Optional[str] = field(
         default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
     )
+    model_type: str = field(default="CrossEncoder")
 
 
 @dataclass
diff --git a/FlagEmbedding/reranker/data.py b/FlagEmbedding/reranker/data.py
index 5ddf0972..daf52b44 100644
--- a/FlagEmbedding/reranker/data.py
+++ b/FlagEmbedding/reranker/data.py
@@ -10,7 +10,7 @@ from transformers import DataCollatorWithPadding
 from transformers import PreTrainedTokenizer, BatchEncoding
 
-from .arguments import DataArguments
+from arguments import DataArguments
 
 
 class TrainDatasetForCE(Dataset):
@@ -62,6 +62,29 @@ def __getitem__(self, item) -> List[BatchEncoding]:
 
         return batch_data
 
+class TrainDatasetForCL(TrainDatasetForCE):
+    def create_one_example(self, input):
+        item = self.tokenizer(
+            input,
+            truncation=True,
+            max_length=self.args.max_len,
+            padding=False,
+        )
+        return item
+
+    def __getitem__(self, item) -> List[BatchEncoding]:
+        query = self.dataset[item]['query']
+        pos = random.choice(self.dataset[item]['pos'])
+        if len(self.dataset[item]['neg']) < self.args.train_group_size - 1:
+            # not enough negatives: tile the list so enough can be sampled
+            num = math.ceil((self.args.train_group_size - 1) / len(self.dataset[item]['neg']))
+            negs = random.sample(self.dataset[item]['neg'] * num, self.args.train_group_size - 1)
+        else:
+            negs = random.sample(self.dataset[item]['neg'], self.args.train_group_size - 1)
+        batch_data = []
+        batch_data.append(self.create_one_example(query))
+        batch_data.append(self.create_one_example(pos))
+        for neg in negs:
+            batch_data.append(self.create_one_example(neg))
+        return batch_data
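A quick worked check of the tiling branch in `TrainDatasetForCL.__getitem__` (standalone sketch, not project code): with `train_group_size = 16` but only five negatives on hand, the list is tiled three times so that fifteen can be sampled.

import math
import random

negs = ["n1", "n2", "n3", "n4", "n5"]
train_group_size = 16                                # wants 15 negatives per query
num = math.ceil((train_group_size - 1) / len(negs))  # ceil(15 / 5) = 3 copies
sampled = random.sample(negs * num, train_group_size - 1)
assert len(sampled) == 15                            # duplicates across copies are expected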
 
 
 @dataclass
diff --git a/FlagEmbedding/reranker/embedding_proj_run.py b/FlagEmbedding/reranker/embedding_proj_run.py
new file mode 100644
index 00000000..306ddf97
--- /dev/null
+++ b/FlagEmbedding/reranker/embedding_proj_run.py
@@ -0,0 +1,125 @@
+import logging
+import os
+from pathlib import Path
+
+from transformers import AutoConfig, AutoTokenizer, TrainingArguments
+from transformers import (
+    HfArgumentParser,
+    set_seed,
+)
+from arguments import ModelArguments, DataArguments
+from data import TrainDatasetForCE, GroupCollator
+from modeling import CLProjEncoder
+from trainer import CETrainer
+
+logger = logging.getLogger(__name__)
+from pprint import pprint as pp
+import sys
+sys.path.append("/opt/tiger/FlagEmbedding")
+from FlagEmbedding.reranker.data import TrainDatasetForCL
+from utils import get_complete_last_checkpoint
+import transformers
+import os
+os.environ["WANDB_DISABLED"] = "true"
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
+    """Collect the state dict on CPU and dump it to disk."""
+    state_dict = trainer.model.state_dict()
+    if trainer.args.should_save:
+        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
+        del state_dict
+        trainer._save(output_dir, state_dict=cpu_state_dict)
+
+def main():
+    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
+    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+    model_args: ModelArguments
+    data_args: DataArguments
+    training_args: TrainingArguments
+
+    # for args in (model_args, data_args, training_args): pp(args)
+
+    # check and load checkpoint
+    last_checkpoint = None
+    if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
+        last_checkpoint = get_complete_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+            logger.info(
+                f"Output directory ({training_args.output_dir}) already exists but contains no complete checkpoint. "
+                "Training from scratch."
+            )
+        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+            logger.info(
+                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+            )
+
+    # Setup logging
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
+    )
+    logger.warning(
+        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
+        training_args.local_rank,
+        training_args.device,
+        training_args.n_gpu,
+        bool(training_args.local_rank != -1),
+        training_args.fp16,
+    )
+    logger.info("Training/evaluation parameters %s", training_args)
+    logger.info("Model parameters %s", model_args)
+    logger.info("Data parameters %s", data_args)
+
+    set_seed(training_args.seed)
+
+    num_labels = 1
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
+        cache_dir=model_args.cache_dir,
+        use_fast=False,
+    )
+
+    config = AutoConfig.from_pretrained(
+        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
+        num_labels=num_labels,
+        cache_dir=model_args.cache_dir,
+        trust_remote_code=True
+    )
+    _model_class = CLProjEncoder
+
+    model = _model_class.from_pretrained(
+        model_args, data_args, training_args,
+        model_args.model_name_or_path,
+        from_tf=bool(".ckpt" in model_args.model_name_or_path),
+        config=config,
+        cache_dir=model_args.cache_dir,
+        trust_remote_code=True
+    )
+
+    checkpoint = None
+    if training_args.resume_from_checkpoint is not None:
+        logger.info(f"train start from {training_args.resume_from_checkpoint}")
+        checkpoint = training_args.resume_from_checkpoint
+    elif last_checkpoint is not None:
+        logger.info(f"train start from {last_checkpoint}")
+        checkpoint = last_checkpoint
+
+    train_dataset = TrainDatasetForCL(data_args, tokenizer=tokenizer)
+    _trainer_class = CETrainer
+
+    trainer = _trainer_class(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        data_collator=GroupCollator(tokenizer),  # the collator still flattens every group into one batch
+        tokenizer=tokenizer
+    )
+    trainer.train(resume_from_checkpoint=checkpoint)
+    trainer.save_state()
+    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
+
+if __name__ == "__main__":
+    main()
diff --git a/FlagEmbedding/reranker/embedding_reranker.ipynb b/FlagEmbedding/reranker/embedding_reranker.ipynb
new file mode 100644
index 00000000..b9c6f0d7
--- /dev/null
+++ b/FlagEmbedding/reranker/embedding_reranker.ipynb
@@ -0,0 +1,613 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import logging\n",
+    "\n",
+    "import torch\n",
+    "from torch import nn\n",
+    "from transformers import AutoModelForSequenceClassification, PreTrainedModel, TrainingArguments, AutoTokenizer\n",
+    "from transformers.modeling_outputs import SequenceClassifierOutput"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = AutoModelForSequenceClassification.from_pretrained(\"/mnt/bn/data-tns-live-llm/leon/experiments/llm/fcbank/reranker_group30_batch2_v100\", torch_dtype=torch.float16, device_map=\"auto\")\n",
+    "tokenizer = AutoTokenizer.from_pretrained(\"/mnt/bn/data-tns-live-llm/leon/experiments/llm/fcbank/reranker_group30_batch2_v100\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 42,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "<class 'torch.Tensor'>\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "from transformers import AutoModelForSequenceClassification\n",
+    "\n",
+    "# assume `model` is an already-loaded model\n",
+    "# model = AutoModelForSequenceClassification.from_pretrained(...)\n",
+    "\n",
+    "# assume group_size is the number of samples the model handles per group\n",
+    "group_size = 15  # set according to the actual setup\n",
+    "batch_size = 2\n",
+    "\n",
+    "# assume we already have a text sequence and its label\n",
+    "text = \"这是一个示例文本。\"  # a sample sentence\n",
+    "label = 1  # assume the label is 1\n",
+    "\n",
+    "# encode the text with the model's tokenizer\n",
+    "encoded_input = tokenizer(text, padding='max_length', truncation=True, max_length=512, return_tensors='pt')\n",
+    "\n",
+    "# tile the encoded input to match batch_size * group_size\n",
+    "input_ids = encoded_input['input_ids'].repeat_interleave(batch_size*group_size, dim=0).cuda()\n",
+    "attention_mask = encoded_input['attention_mask'].repeat_interleave(batch_size*group_size, dim=0).cuda()\n",
+    "\n",
+    "print(type(input_ids))\n",
+    "\n",
+    "# build the batch dict\n",
+    "batch = {\n",
+    "    'input_ids': input_ids,\n",
+    "    'attention_mask': attention_mask,\n",
+    "}\n",
+    "\n",
+    "labels = torch.tensor([label]*batch_size, dtype=torch.long).cuda()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 52,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_embedding(input_ids, attention_mask, model=model, tokenizer=tokenizer):\n",
+    "    hidden_state = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True).hidden_states[-1].cpu()\n",
+    "    attention_mask = attention_mask.cpu()\n",
+    "    seq_lengths = attention_mask.sum(dim=1)\n",
+    "    embeddings = []\n",
+    "    for seq_len, seq_emb in zip(seq_lengths, hidden_state):\n",
+    "        valid_emb = seq_emb[:seq_len]\n",
+    "        embeddings.append(torch.mean(valid_emb, dim=0))\n",
+    "\n",
+    "    embedding = torch.stack(embeddings)\n",
+    "    return embedding"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 44,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def forward(batch):\n",
+    "    cross_entropy = nn.CrossEntropyLoss(reduction='mean')\n",
+    "    embeddings = get_embedding(**batch)\n",
+    "    loss = batchloss(embeddings)\n",
+    "    return loss"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 54,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([[-0.2715, -0.0060, -0.6865,  ..., -0.9995, -0.5493,  0.4402],\n",
+      "        [-0.2715, -0.0060, -0.6865,  ..., -0.9995, -0.5493,  0.4402],\n",
+      "        [-0.2715, -0.0060, -0.6865,  ..., -0.9995, -0.5493,  0.4402],\n",
+      "        ...,\n",
+      "        [-0.2715, -0.0060, -0.6865,  ..., -0.9995, -0.5493,  0.4402],\n",
+      "        [-0.2715, -0.0060, -0.6865,  ..., -0.9995, -0.5493,  0.4402],\n",
+      "        [-0.2715, -0.0060, -0.6865,  ..., -0.9995, -0.5493,  0.4402]],\n",
+      "       dtype=torch.float16, grad_fn=<StackBackward0>)\n",
+      "torch.Size([30, 768])\n"
+     ]
+    }
+   ],
+   "source": [
+    "results = get_embedding(**batch)\n",
+    "print(results)\n",
+    "print(results.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 56,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "torch.Size([2, 15, 768])\n"
+     ]
+    }
+   ],
+   "source": [
+    "pred = results.view(batch_size, group_size, -1)\n",
+    "print(pred.shape)"
+   ]
+  },
temperature\n", + " # 计算InfoNCE损失\n", + " loss = - torch.log(torch.exp(pos_similarity)/torch.sum(torch.exp(all_similarity)))\n", + " return loss.mean()\n", + "\n", + "def batchloss(embeddings):\n", + " # 遍历每个batch计算损失\n", + " losses = []\n", + " for i in range(embeddings.size(0)):\n", + " # anchor embeddings\n", + " anchor = embeddings[i, 0].unsqueeze(0) # [1, 768]\n", + " # positive embeddings\n", + " positive = embeddings[i, 1].unsqueeze(0) # [1, 768]\n", + " # 除了anchor和positive之外的所有embeddings作为负样本\n", + " negatives = embeddings[i, 2:] # [13, 768]\n", + " # 计算当前batch的InfoNCE损失\n", + " loss = infoNCELoss(anchor, positive, negatives)\n", + " losses.append(loss)\n", + " # 计算整个batch的平均损失\n", + " batch_loss = torch.mean(torch.stack(losses))\n", + " return batch_loss" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(2.6431)\n" + ] + } + ], + "source": [ + "# 假设 embeddings 是一个形状为 [batch, group_size, embedding_len] 的张量\n", + "embeddings = torch.randn(2, 15, 768) # 示例数据\n", + "print(batchloss(embeddings))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 去掉模型的分类头" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "XLMRobertaModel(\n", + " (embeddings): XLMRobertaEmbeddings(\n", + " (word_embeddings): Embedding(250002, 768, padding_idx=1)\n", + " (position_embeddings): Embedding(514, 768, padding_idx=1)\n", + " (token_type_embeddings): Embedding(1, 768)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (encoder): XLMRobertaEncoder(\n", + " (layer): ModuleList(\n", + " (0-11): 12 x XLMRobertaLayer(\n", + " (attention): XLMRobertaAttention(\n", + " (self): XLMRobertaSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): XLMRobertaSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): XLMRobertaIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " (intermediate_act_fn): GELUActivation()\n", + " )\n", + " (output): XLMRobertaOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (pooler): XLMRobertaPooler(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (activation): Tanh()\n", + " )\n", + ")\n" + ] + } + ], + "source": [ + "from transformers import XLMRobertaForSequenceClassification, AutoModel, AutoTokenizer\n", + "import torch\n", + "model = AutoModel.from_pretrained('/mnt/bn/data-tns-live-llm/leon/experiments/llm/fcbank/xlmr/models--FacebookAI--xlm-roberta-base/snapshots/e73636d4f797dec63c3081bb6ed5c7b0bb3f2089/', torch_dtype=torch.float16)\n", + "tokenizer = 
+    "# print(type(model.modules()))\n",
+    "print(model)\n",
+    "# print(model.roberta)\n",
+    "# print(model.classifier)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 78,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "XLMRobertaForSequenceClassification(\n",
+      "  (roberta): XLMRobertaModel(\n",
+      "    (embeddings): XLMRobertaEmbeddings(\n",
+      "      (word_embeddings): Embedding(250002, 768, padding_idx=1)\n",
+      "      (position_embeddings): Embedding(514, 768, padding_idx=1)\n",
+      "      (token_type_embeddings): Embedding(1, 768)\n",
+      "      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+      "      (dropout): Dropout(p=0.1, inplace=False)\n",
+      "    )\n",
+      "    (encoder): XLMRobertaEncoder(\n",
+      "      (layer): ModuleList(\n",
+      "        (0-11): 12 x XLMRobertaLayer(\n",
+      "          (attention): XLMRobertaAttention(\n",
+      "            (self): XLMRobertaSelfAttention(\n",
+      "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
+      "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
+      "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
+      "              (dropout): Dropout(p=0.1, inplace=False)\n",
+      "            )\n",
+      "            (output): XLMRobertaSelfOutput(\n",
+      "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+      "              (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+      "              (dropout): Dropout(p=0.1, inplace=False)\n",
+      "            )\n",
+      "          )\n",
+      "          (intermediate): XLMRobertaIntermediate(\n",
+      "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+      "            (intermediate_act_fn): GELUActivation()\n",
+      "          )\n",
+      "          (output): XLMRobertaOutput(\n",
+      "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+      "            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+      "            (dropout): Dropout(p=0.1, inplace=False)\n",
+      "          )\n",
+      "        )\n",
+      "      )\n",
+      "    )\n",
+      "  )\n",
+      ")\n"
+     ]
+    }
+   ],
+   "source": [
+    "del model.classifier\n",
+    "print(model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sentence_transformers import SentenceTransformer, models"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# A look at how the training data is constructed"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "query = \"hello\"\n",
+    "pos = \"hello\"\n",
+    "negs = [\"hello\",\"hello\"]\n",
+    "\n",
+    "batch_data = tokenizer([query]+[pos]+negs, padding=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# A look at an open-source sentence encoder"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "... (model download progress bars and two-GPU NCCL initialization logs omitted) ...\n",
+      "[[0.855 0.852 ]\n",
+      " [0.874 0.8555]]\n",
+      "... (single-GPU NCCL re-initialization logs omitted) ...\n",
+      "[[0.652 0.4424]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "from FlagEmbedding import FlagModel\n",
+    "sentences_1 = [\"样例数据-1\", \"样例数据-2\"]\n",
+    "sentences_2 = [\"样例数据-3\", \"样例数据-4\"]\n",
+    "model = FlagModel('BAAI/bge-large-zh-v1.5', \n",
+    "                  query_instruction_for_retrieval=\"为这个句子生成表示以用于检索相关文章:\",\n",
+    "                  use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation\n",
+    "embeddings_1 = model.encode(sentences_1)\n",
+    "embeddings_2 = model.encode(sentences_2)\n",
+    "similarity = embeddings_1 @ embeddings_2.T\n",
+    "print(similarity) \n",
+    "\n",
+    "# for s2p(short query to long passage) retrieval task, suggest to use encode_queries() which will automatically add the instruction to each query\n",
+    "# corpus in retrieval task can still use encode() or encode_corpus(), since they don't need instruction\n",
+    "queries = ['When was quantum field theory developed?']\n",
+    "passages = [\"Quantum field theory naturally began with the study of electromagnetic interactions, as the electromagnetic field was the only known classical field as of the 1920s.[8]:1\", \"Cumrun Vafa is a string theorist. His research is focused on the nature of quantum gravity and the relation between geometry and quantum field theories. He is known in the string theory community for his co-discovery, with Strominger, that the Bekenstein-Hawking entropy of a black hole can be accounted for by solitonic states of superstring theory, and for expounding the relation between geometry and field theories that arise through string dualities (culminating in the Gopakumar\\u2013Vafa conjecture). This topic has been known as \\\"geometric engineering of quantum field theories\\\". In 1997, he developed F-theory.\"]\n",
+    "q_embeddings = model.encode_queries(queries)\n",
+    "p_embeddings = model.encode(passages)\n",
+    "scores = q_embeddings @ p_embeddings.T\n",
+    "print(scores)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.9.2 64-bit",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.2"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/FlagEmbedding/reranker/embedding_run.py b/FlagEmbedding/reranker/embedding_run.py
new file mode 100644
index 00000000..97691803
--- /dev/null
+++ b/FlagEmbedding/reranker/embedding_run.py
@@ -0,0 +1,125 @@
+import logging
+import os
+from pathlib import Path
+
+from transformers import AutoConfig, AutoTokenizer, TrainingArguments
+from transformers import (
+    HfArgumentParser,
+    set_seed,
+)
+from arguments import ModelArguments, DataArguments
+from data import TrainDatasetForCE, GroupCollator
+from modeling import CLEncoder
+from trainer import CETrainer
+
+logger = logging.getLogger(__name__)
+from pprint import pprint as pp
+import sys
+sys.path.append("/opt/tiger/FlagEmbedding")
+from FlagEmbedding.reranker.data import TrainDatasetForCL
+from utils import get_complete_last_checkpoint
+import transformers
+import os
+os.environ["WANDB_DISABLED"] = "true"
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
+    """Collect the state dict on CPU and dump it to disk."""
+    state_dict = trainer.model.state_dict()
+    if trainer.args.should_save:
+        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
+        del state_dict
+        trainer._save(output_dir, state_dict=cpu_state_dict)
+
+def main():
+    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
+    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+    model_args: ModelArguments
+    data_args: DataArguments
+    training_args: TrainingArguments
+
+    # for args in (model_args, data_args, training_args): pp(args)
+
+    # check and load checkpoint
+    last_checkpoint = None
+    if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
+        last_checkpoint = get_complete_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+            logger.info(
+                f"Output directory ({training_args.output_dir}) already exists but contains no complete checkpoint. "
+                "Training from scratch."
+            )
+        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+            logger.info(
+                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+            )
+
+    # Setup logging
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
+    )
+    logger.warning(
+        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
+        training_args.local_rank,
+        training_args.device,
+        training_args.n_gpu,
+        bool(training_args.local_rank != -1),
+        training_args.fp16,
+    )
+    logger.info("Training/evaluation parameters %s", training_args)
+    logger.info("Model parameters %s", model_args)
+    logger.info("Data parameters %s", data_args)
+
+    set_seed(training_args.seed)
+
+    num_labels = 1
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
+        cache_dir=model_args.cache_dir,
+        use_fast=False,
+    )
+
+    config = AutoConfig.from_pretrained(
+        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
+        num_labels=num_labels,
+        cache_dir=model_args.cache_dir,
+        trust_remote_code=True
+    )
+    _model_class = CLEncoder
+
+    model = _model_class.from_pretrained(
+        model_args, data_args, training_args,
+        model_args.model_name_or_path,
+        from_tf=bool(".ckpt" in model_args.model_name_or_path),
+        config=config,
+        cache_dir=model_args.cache_dir,
+        trust_remote_code=True
+    )
+
+    checkpoint = None
+    if training_args.resume_from_checkpoint is not None:
+        logger.info(f"train start from {training_args.resume_from_checkpoint}")
+        checkpoint = training_args.resume_from_checkpoint
+    elif last_checkpoint is not None:
+        logger.info(f"train start from {last_checkpoint}")
+        checkpoint = last_checkpoint
+
+    train_dataset = TrainDatasetForCL(data_args, tokenizer=tokenizer)
+    _trainer_class = CETrainer
+
+    trainer = _trainer_class(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        data_collator=GroupCollator(tokenizer),  # the collator still flattens every group into one batch
+        tokenizer=tokenizer
+    )
+    trainer.train(resume_from_checkpoint=checkpoint)
+    trainer.save_state()
+    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
+
+if __name__ == "__main__":
+    main()
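The "still flattens" comment on `GroupCollator` refers to each dataset item arriving as a list of query/pos/neg encodings. A rough sketch of that behavior, written from the upstream reranker collator's general shape rather than copied from the repo:

from dataclasses import dataclass
from transformers import DataCollatorWithPadding

@dataclass
class FlatteningGroupCollator(DataCollatorWithPadding):
    def __call__(self, features):
        # each feature is [query_enc, pos_enc, neg_enc, ...]; flatten all groups
        # so padding runs over batch_size * (group + 1) rows in one pass
        if isinstance(features[0], list):
            features = sum(features, [])
        return super().__call__(features)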
diff --git a/FlagEmbedding/reranker/modeling.py b/FlagEmbedding/reranker/modeling.py
index 244c99fa..c4224fc1 100644
--- a/FlagEmbedding/reranker/modeling.py
+++ b/FlagEmbedding/reranker/modeling.py
@@ -2,10 +2,12 @@
 import torch
 from torch import nn
-from transformers import AutoModelForSequenceClassification, PreTrainedModel, TrainingArguments
+from transformers import AutoModelForSequenceClassification, PreTrainedModel, TrainingArguments, AutoModel
 from transformers.modeling_outputs import SequenceClassifierOutput
 
-from .arguments import ModelArguments, DataArguments
+from arguments import ModelArguments, DataArguments
+import torch
+import torch.nn.functional as F
 
 logger = logging.getLogger(__name__)
@@ -34,6 +36,7 @@ def forward(self, batch):
         ranker_out: SequenceClassifierOutput = self.hf_model(**batch, return_dict=True)
         logits = ranker_out.logits
 
+        # effectively a group_size-way classification task over the per-pair CLS scores
         if self.training:
             scores = logits.view(
                 self.train_args.per_device_train_batch_size,
@@ -53,7 +56,7 @@ def from_pretrained(
         cls, model_args: ModelArguments, data_args: DataArguments, train_args: TrainingArguments,
         *args, **kwargs
     ):
-        hf_model = AutoModelForSequenceClassification.from_pretrained(*args, **kwargs)
+        hf_model = AutoModelForSequenceClassification.from_pretrained(*args, **kwargs)  # XLM-R itself ships without a classification head
         reranker = cls(hf_model, model_args, data_args, train_args)
         return reranker
 
@@ -64,3 +67,137 @@ def save_pretrained(self, output_dir: str):
                                    for k, v in state_dict.items()})
         self.hf_model.save_pretrained(output_dir, state_dict=state_dict)
+
+class CLEncoder(CrossEncoder):
+    def __init__(self, hf_model: PreTrainedModel, model_args: ModelArguments, data_args: DataArguments,
+                 train_args: TrainingArguments, pooling_method='cls'):
+        super(CrossEncoder, self).__init__()
+        self.hf_model = hf_model
+        self.model_args = model_args
+        self.train_args = train_args
+        self.data_args = data_args
+        self.config = self.hf_model.config
+        self.pooling_method = pooling_method
+
+    @classmethod
+    def from_pretrained(
+        cls, model_args: ModelArguments, data_args: DataArguments, train_args: TrainingArguments,
+        *args, **kwargs
+    ):
+        hf_model = AutoModel.from_pretrained(*args, **kwargs)
+        try:
+            del hf_model.classifier
+        except AttributeError:
+            print("model has no classifier head")
+
+        reranker = cls(hf_model, model_args, data_args, train_args)
+        return reranker
+
+    def pooling(self,
+                last_hidden_state: torch.Tensor,
+                attention_mask: torch.Tensor = None):
+        if self.pooling_method == 'cls':
+            return last_hidden_state[:, 0]
+        elif self.pooling_method == 'mean':
+            s = torch.sum(last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
+            d = attention_mask.sum(dim=1, keepdim=True).float()
+            return s / d
+
+    def get_embedding(self, input_ids, attention_mask):
+        # note: the last hidden state is moved to CPU before pooling and normalization
+        hidden_state = self.hf_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True).hidden_states[-1].cpu()
+        attention_mask = attention_mask.cpu()
+        embeddings = self.pooling(hidden_state, attention_mask)
+        embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
+        return embeddings
+
+    def infoNCELoss(self, anchor, positive, negatives, temperature=1):
+        # similarity of the anchor to the positive
+        pos_similarity = F.cosine_similarity(anchor, positive, dim=-1)
+        # similarity of the anchor to every negative (the anchor broadcasts across them)
+        neg_similarity = F.cosine_similarity(anchor, negatives, dim=-1)
+        # concatenate the positive and negative similarities
+        # print(pos_similarity.shape)
+        # print(neg_similarity.shape)
+        all_similarity = torch.cat([pos_similarity, neg_similarity])
+        # apply temperature scaling
+        all_similarity /= temperature
+        # InfoNCE loss; note the numerator must use the temperature-scaled positive as well
+        loss = - torch.log(torch.exp(pos_similarity / temperature)/torch.sum(torch.exp(all_similarity)))
+        return loss.mean()
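The explicit Python loop in `infoNCELoss`/`batchloss` below collapses into a single cross-entropy call. A loop-free equivalent sketch (an illustration, not the repo's code), assuming as above that slot 0 of each group is the anchor, slot 1 the positive, and the rest negatives:

import torch
import torch.nn.functional as F

def batch_infonce(embeddings: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # embeddings: [batch, group, dim]; rows are normalized so dot product = cosine
    emb = F.normalize(embeddings, dim=-1)
    anchor = emb[:, :1, :]                                 # [batch, 1, dim]
    candidates = emb[:, 1:, :]                             # [batch, group - 1, dim]
    logits = (anchor * candidates).sum(-1) / temperature   # cosine sims, [batch, group - 1]
    target = torch.zeros(emb.size(0), dtype=torch.long, device=emb.device)
    return F.cross_entropy(logits, target)                 # positive sits at column 0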
+
+    def batchloss(self, embeddings):
+        # compute the loss group by group
+        losses = []
+        for i in range(embeddings.size(0)):
+            # anchor embeddings
+            anchor = embeddings[i, 0].unsqueeze(0)  # [1, 768]
+            # positive embeddings
+            positive = embeddings[i, 1].unsqueeze(0)  # [1, 768]
+            # every embedding other than the anchor and positive acts as a negative
+            negatives = embeddings[i, 2:]  # [len(negs), 768]
+            # InfoNCE loss for the current group
+            # print("anchor", anchor.shape)
+            # print("positive", positive.shape)
+            # print("negatives", negatives.shape)
+            loss = self.infoNCELoss(anchor, positive, negatives)
+            losses.append(loss)
+        # average loss over the whole batch
+        batch_loss = torch.mean(torch.stack(losses))
+        return batch_loss
+
+    def forward(self, batch):
+        embeddings = self.get_embedding(batch["input_ids"], batch["attention_mask"])
+        # each group holds query + pos + (train_group_size - 1) negatives, i.e. train_group_size + 1 rows
+        embeddings = embeddings.reshape(self.train_args.per_device_train_batch_size, self.data_args.train_group_size+1, -1)
+        # print("embeddings", embeddings.shape)
+        loss = self.batchloss(embeddings).cuda()
+        # effectively a group_size-way classification task over the group similarities
+        if self.training:
+            return SequenceClassifierOutput(
+                loss=loss,
+                hidden_states=embeddings,
+            )
+        else:
+            return embeddings
+
+# projection head
+class SimpleResBlock(nn.Module):
+    def __init__(self, channels=768):
+        super().__init__()
+        self.pre_norm = nn.LayerNorm(channels)
+
+        self.proj = nn.Sequential(
+            nn.Linear(channels, channels),
+            nn.GELU(),
+            nn.Linear(channels, channels)
+        )
+
+    def forward(self, x):
+        x = self.pre_norm(x)
+        return x + self.proj(x)
+
+class CLProjEncoder(CLEncoder):
+    def __init__(self, hf_model: PreTrainedModel, model_args: ModelArguments, data_args: DataArguments, train_args: TrainingArguments):
+        super().__init__(hf_model, model_args, data_args, train_args)
+        channels = 768
+        # self.pre_norm = nn.LayerNorm(channels)
+        self.proj = nn.Sequential(
+            nn.Linear(channels, channels),
+            nn.GELU(),
+            nn.Linear(channels, channels)
+        )
+
+    def forward(self, batch):
+        embeddings = self.get_embedding(**batch)
+        embeddings = embeddings.reshape(self.train_args.per_device_train_batch_size, self.data_args.train_group_size+1, -1)
+        # apply the projection to the query embedding only (index 0 of each group)
+        querys = embeddings[:, 0, :]
+        embeddings[:, 0, :] = self.proj(querys.cuda()).cpu()
+        # print("embeddings", embeddings.shape)
+        loss = self.batchloss(embeddings).cuda()
+        # effectively a group_size-way classification task over the group similarities
+        if self.training:
+            return SequenceClassifierOutput(
+                loss=loss,
+                hidden_states=embeddings,
+            )
+        else:
+            return embeddings
\ No newline at end of file
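For the original `CrossEncoder` path, the "group_size-way classification" comments describe the following shape handling. A standalone illustration (not the repo's exact code), under the assumption made by `TrainDatasetForCE` that the positive passage always comes first in each group:

import torch
import torch.nn.functional as F

batch_size, group_size = 2, 4
logits = torch.randn(batch_size * group_size, 1)    # one ranker score per (query, passage) pair
scores = logits.view(batch_size, group_size)        # one row of candidate scores per query
target = torch.zeros(batch_size, dtype=torch.long)  # the positive sits at index 0
loss = F.cross_entropy(scores, target)              # softmax over the group, per query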
diff --git a/FlagEmbedding/reranker/run.py b/FlagEmbedding/reranker/run.py
index 81a51850..14cf97b5 100644
--- a/FlagEmbedding/reranker/run.py
+++ b/FlagEmbedding/reranker/run.py
@@ -2,19 +2,39 @@
 import os
 from pathlib import Path
 
-from transformers import AutoConfig, AutoTokenizer, TrainingArguments
+from transformers import AutoConfig, AutoTokenizer, TrainingArguments, TrainerCallback, TrainerState, TrainerControl
 from transformers import (
     HfArgumentParser,
     set_seed,
 )
 
-from .arguments import ModelArguments, DataArguments
-from .data import TrainDatasetForCE, GroupCollator
-from .modeling import CrossEncoder
-from .trainer import CETrainer
+from arguments import ModelArguments, DataArguments
+from data import TrainDatasetForCE, GroupCollator, TrainDatasetForCL
+from modeling import CLEncoder, CLProjEncoder, CrossEncoder
+from trainer import CETrainer
 
 logger = logging.getLogger(__name__)
-
+from pprint import pprint as pp
+import sys
+sys.path.append("/opt/tiger/FlagEmbedding")
+from utils import get_complete_last_checkpoint
+import transformers
+import os
+os.environ["WANDB_DISABLED"] = "true"
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
+    """Collect the state dict on CPU and dump it to disk."""
+    state_dict = trainer.model.state_dict()
+    if trainer.args.should_save:
+        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
+        del state_dict
+        trainer._save(output_dir, state_dict=cpu_state_dict)
+
+# log positive and negative similarities during training
+# class Loggingback(TrainerCallback):
+#     def on_train_begin(self, args, state: TrainerState, control: TrainerControl, **kwargs):
+#         if state.global_step % 10 == 0:
+#             control.
 
 def main():
     parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
@@ -23,15 +43,22 @@ def main():
     data_args: DataArguments
     training_args: TrainingArguments
 
-    if (
-        os.path.exists(training_args.output_dir)
-        and os.listdir(training_args.output_dir)
-        and training_args.do_train
-        and not training_args.overwrite_output_dir
-    ):
-        raise ValueError(
-            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
-        )
+    # for args in (model_args, data_args, training_args): pp(args)
+
+    # check and load checkpoint
+    last_checkpoint = None
+    if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
+        last_checkpoint = get_complete_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+            logger.info(
+                f"Output directory ({training_args.output_dir}) already exists but contains no complete checkpoint. "
+                "Training from scratch."
+            )
+        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+            logger.info(
+                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+            )
 
     # Setup logging
     logging.basicConfig(
@@ -64,8 +91,15 @@ def main():
         model_args.config_name if model_args.config_name else model_args.model_name_or_path,
         num_labels=num_labels,
         cache_dir=model_args.cache_dir,
+        trust_remote_code=True
     )
-    _model_class = CrossEncoder
+
+    if model_args.model_type == "CrossEncoder":
+        _model_class = CrossEncoder
+    elif model_args.model_type == "CLEncoder":
+        _model_class = CLEncoder
+    else:
+        _model_class = CLProjEncoder
+
+    if model_args.model_type == "CrossEncoder":
+        train_dataset = TrainDatasetForCE(data_args, tokenizer=tokenizer)
+    else:
+        train_dataset = TrainDatasetForCL(data_args, tokenizer=tokenizer)
 
     model = _model_class.from_pretrained(
         model_args, data_args, training_args,
@@ -73,9 +107,17 @@ def main():
         from_tf=bool(".ckpt" in model_args.model_name_or_path),
         config=config,
         cache_dir=model_args.cache_dir,
+        trust_remote_code=True
     )
 
-    train_dataset = TrainDatasetForCE(data_args, tokenizer=tokenizer)
+    checkpoint = None
+    if training_args.resume_from_checkpoint is not None:
+        logger.info(f"train start from {training_args.resume_from_checkpoint}")
+        checkpoint = training_args.resume_from_checkpoint
+    elif last_checkpoint is not None:
+        logger.info(f"train start from {last_checkpoint}")
+        checkpoint = last_checkpoint
+
     _trainer_class = CETrainer
     trainer = _trainer_class(
         model=model,
@@ -84,11 +126,13 @@ def main():
         data_collator=GroupCollator(tokenizer),
         tokenizer=tokenizer
     )
+    trainer.train(resume_from_checkpoint=checkpoint)
+    trainer.save_state()
+    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
+    # Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)
 
-    Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)
-
-    trainer.train()
-    trainer.save_model()
+    # trainer.train()
+    # trainer.save_model()
 
 
 if __name__ == "__main__":
diff --git a/FlagEmbedding/reranker/trainer.py b/FlagEmbedding/reranker/trainer.py
index 7dbc5881..60d8e215 100644
--- a/FlagEmbedding/reranker/trainer.py
+++ b/FlagEmbedding/reranker/trainer.py
@@ -5,7 +5,7 @@ import torch
 from transformers.trainer import Trainer
 
-from .modeling import CrossEncoder
+from modeling import CrossEncoder
 
 logger = logging.getLogger(__name__)
@@ -29,3 +29,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
 
     def compute_loss(self, model: CrossEncoder, inputs):
         return model(inputs)['loss']
+
+    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
+        # resume by loading weights into the wrapped hf_model rather than the wrapper
+        model = self.model.hf_model
+        super()._load_from_checkpoint(resume_from_checkpoint, model)
\ No newline at end of file
diff --git a/debug_args.ipynb b/debug_args.ipynb
new file mode 100644
index 00000000..32465205
--- /dev/null
+++ b/debug_args.ipynb
@@ -0,0 +1,258 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[\"--output_dir\", \"/mnt/bn/data-tns-live-llm/leon/experiments/llm/fcbank/toy\", \"--model_name_or_path\", \"/mnt/bn/data-tns-live-llm/leon/experiments/llm/fcbank/saved_model_v100\", \"--train_data\", \"/opt/tiger/toy_data.jsonl\", \"--learning_rate\", \"1e-5\", \"--fp16\", \"--num_train_epochs\", \"3\", \"--per_device_train_batch_size\", \"4\", \"--gradient_accumulation_steps\", \"4\", \"--dataloader_drop_last\", \"--train_group_size\", \"16\", \"--max_len\", \"512\", \"--weight_decay\", \"0.01\", \"--logging_steps\", \"10\", \"--save_strategy\", \"epoch\", \"--save_steps\", \"1\", \"--save_total_limit\", \"3\"]\n"
+     ]
+    }
+   ],
+   "source": [
+    "import 
json\n", + "args = (str(\"--output_dir /mnt/bn/data-tns-live-llm/leon/experiments/llm/fcbank/toy --model_name_or_path /mnt/bn/data-tns-live-llm/leon/experiments/llm/fcbank/saved_model_v100 --train_data /opt/tiger/toy_data.jsonl --learning_rate 1e-5 --fp16 --num_train_epochs 3 --per_device_train_batch_size 4 --gradient_accumulation_steps 4 --dataloader_drop_last --train_group_size 16 --max_len 512 --weight_decay 0.01 --logging_steps 10 --save_strategy epoch --save_steps 1 --save_total_limit 3\".split(\" \")).replace(\"\\'\",\"\\\"\"))\n", + "print(args)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import transformers\n", + "from transformers import AutoTokenizer, AutoConfig, AutoModel\n", + "from transformers import TextGenerationPipeline, AutoModelForCausalLM, LlamaTokenizerFast\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "###config: RobertaConfig {\n", + " \"_name_or_path\": \"/mnt/bn/data-tns-live-llm/leon/experiments/llm/fcbank/encoder_simcse_group15_batch1_a100_bge_manner/checkpoint-8000\",\n", + " \"architectures\": [\n", + " \"RobertaModel\"\n", + " ],\n", + " \"attention_probs_dropout_prob\": 0.1,\n", + " \"bos_token_id\": 0,\n", + " \"classifier_dropout\": null,\n", + " \"eos_token_id\": 2,\n", + " \"gradient_checkpointing\": false,\n", + " \"hidden_act\": \"gelu\",\n", + " \"hidden_dropout_prob\": 0.1,\n", + " \"hidden_size\": 1024,\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 4096,\n", + " \"layer_norm_eps\": 1e-05,\n", + " \"max_position_embeddings\": 514,\n", + " \"model_type\": \"roberta\",\n", + " \"num_attention_heads\": 16,\n", + " \"num_hidden_layers\": 24,\n", + " \"output_past\": true,\n", + " \"pad_token_id\": 1,\n", + " \"position_embedding_type\": \"absolute\",\n", + " \"torch_dtype\": \"float32\",\n", + " \"transformers_version\": \"4.40.2\",\n", + " \"type_vocab_size\": 1,\n", + " \"use_cache\": true,\n", + " \"use_pooler_layer\": true,\n", + " \"vocab_size\": 250002\n", + "}\n", + "\n" + ] + } + ], + "source": [ + "config = transformers.AutoConfig.from_pretrained(\"/mnt/bn/data-tns-live-llm/leon/experiments/llm/fcbank/encoder_simcse_group15_batch1_a100_bge_manner/checkpoint-8000\")\n", + "config.use_pooler_layer = True\n", + "# config.loss_type = model_args.loss_type\n", + "# config.num_labels = model_args.num_labels\n", + "# config.margin = model_args.margin\n", + "# logging(f'Loss Function: {config.loss_type}')\n", + "print('###config: ', config)\n", + "# model = AutoModelForCausalLM.from_pretrained(\"/mnt/bn/data-tns-live-llm/leon/experiments/llm/fcbank/encoder_simcse_group15_batch1_a100_bge_manner/checkpoint-8000\", device_map=\"auto\")\n", + "# tokenizer = AutoTokenizer.from_pretrained(\"/mnt/bn/data-tns-live-llm/leon/experiments/llm/fcbank/encoder_simcse_group15_batch1_a100_bge_manner/checkpoint-8000\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "model = AutoModel.from_pretrained(\n", + " \"/mnt/bn/data-tns-live-llm/leon/experiments/llm/fcbank/encoder_simcse_group15_batch1_a100_bge_manner/checkpoint-8000\",\n", + " config=config\n", + ")\n", + "tokenizer = AutoTokenizer.from_pretrained(\"/mnt/bn/data-tns-live-llm/leon/experiments/llm/fcbank/encoder_simcse_group15_batch1_a100_bge_manner/checkpoint-8000\")" + ] + }, + { + "cell_type": "code", + 
"execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RobertaModel(\n", + " (embeddings): RobertaEmbeddings(\n", + " (word_embeddings): Embedding(250002, 1024, padding_idx=1)\n", + " (position_embeddings): Embedding(514, 1024, padding_idx=1)\n", + " (token_type_embeddings): Embedding(1, 1024)\n", + " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (encoder): RobertaEncoder(\n", + " (layer): ModuleList(\n", + " (0-23): 24 x RobertaLayer(\n", + " (attention): RobertaAttention(\n", + " (self): RobertaSelfAttention(\n", + " (query): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (key): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (value): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): RobertaSelfOutput(\n", + " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): RobertaIntermediate(\n", + " (dense): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (intermediate_act_fn): GELUActivation()\n", + " )\n", + " (output): RobertaOutput(\n", + " (dense): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (pooler): RobertaPooler(\n", + " (dense): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (activation): Tanh()\n", + " )\n", + ")\n" + ] + } + ], + "source": [ + "print(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'input_ids': tensor([[ 0, 33600, 31, 8999, 2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}\n" + ] + } + ], + "source": [ + "query = tokenizer(\"hello world\", return_tensors=\"pt\")\n", + "print(query)\n", + "output = model(**query)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "odict_keys(['last_hidden_state', 'pooler_output'])\n", + "tensor([[-0.4685, -0.1635, -0.0385, ..., -0.1976, 0.1712, 0.0859]],\n", + " grad_fn=)\n" + ] + } + ], + "source": [ + "print(output.keys())\n", + "print(output.pooler_output)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "odict_keys(['last_hidden_state', 'pooler_output'])\n", + "tensor([[-0.4685, -0.1635, -0.0385, ..., -0.1976, 0.1712, 0.0859]],\n", + " grad_fn=)\n" + ] + } + ], + "source": [ + "model1 = AutoModel.from_pretrained(\n", + " \"/mnt/bn/data-tns-live-llm/leon/experiments/llm/fcbank/encoder_simcse_group15_batch1_a100_bge_manner/checkpoint-8000\"\n", + ")\n", + "output1 = model1(**query)\n", + "print(output1.keys())\n", + "print(output.pooler_output)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.2 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", 
+ "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.2" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/finetune/toy_finetune_data_minedHN.jsonl b/examples/finetune/toy_finetune_data_minedHN.jsonl new file mode 100644 index 00000000..a51c6a3c --- /dev/null +++ b/examples/finetune/toy_finetune_data_minedHN.jsonl @@ -0,0 +1,10 @@ +{"query": "Five women walk along a beach wearing flip-flops.", "pos": ["Some women with flip-flops on, are walking along the beach"], "neg": ["The seal of Missouri is perfect.", "George Bush told the Republicans there was no way he would let them even consider this foolish idea, against his top advisors advice.", "They sold their home because they were retiring and not because of the loan.", "A person gives a speech.", "The street was lined with white-painted houses.", "Critical factors for essential activities are set out.", "A man is a pilot of an airplane.", "Conrad was being plotted against, to be hit on the head.", "The fatal dose was not taken when the murderer thought it would be.", "Mother Teresa is an easy choice.", "It lays out critical activities but makes no provision for critical factors related to those activities.", "It is only staged on Winter afternoons in Palma's large bullring.", "A group of Indians are having a funeral", "No matter how old people get they never forget. ", "The morning sunlight was shining brightly and it was warm. "]} +{"query": "A woman standing on a high cliff on one leg looking over a river.", "pos": ["A woman is standing on a cliff."], "neg": ["This is definitely not an endorsement.", "Two men watching a magic show.", "There was a reform in 1996.", "George Bush told the Republicans there was no way he would let them even consider this foolish idea, against his top advisors advice.", "An athlete is competing in the 1500 meter swimming competition.", "The state would prefer for you to do that.", "Conrad was being plotted against, to be hit on the head.", "A man is in a city.", "a dog is running", "man at picnics cut steak", "A girl is wearing blue.", "Meanwhile, the mainland was empty of population.", "A woman sits on a chair.", "Several chefs are sitting down and talking about food.", "Neither the Globe or Mail had comments on the current state of Canada's road system. "]} +{"query": "Two woman are playing instruments; one a clarinet, the other a violin.", "pos": ["Some people are playing a tune."], "neg": ["man at picnics cut steak", "A man is a pilot of an airplane.", "A child is reading in her bedroom.", "This is definitely not an endorsement.", "Neither the Globe or Mail had comments on the current state of Canada's road system. ", "Person in black clothing, with white bandanna and sunglasses waits at a bus stop.", "I face a serious problem at eighteen years old. 
", "A boy is sitting outside playing in the sand.", "A man is skiing down a mountain.", "A woman is jogging in the park.", "People watching a spaceship launch.", "Ended as soon as I received the wire.", "The Spring Creek facility is old and outdated.", "A person gives a speech.", "The girl is standing, leaning against the archway."]} +{"query": "A girl with a blue tank top sitting watching three dogs.", "pos": ["A girl is wearing blue."], "neg": ["A man is skiing down a mountain.", "They sold their home because they were retiring and not because of the loan.", "A group of Indians are having a funeral", "The man is talking about hawaii.", "People watching a spaceship launch.", "The street was lined with white-painted houses.", "A woman is riding her bike.", "Two people jumped off the dock.", "Ended as soon as I received the wire.", "The 4 women are sitting on the beach.", "The Spring Creek facility is old and outdated.", "A woman sits on a chair.", "Steele did not keep her original story.", "People are assembled in protest.", "Meanwhile, the mainland was empty of population."]} +{"query": "A yellow dog running along a forest path.", "pos": ["a dog is running"], "neg": ["The man sits at the table and eats food.", "People watching a spaceship launch.", "Nobody is jumping", "Conrad was being plotted against, to be hit on the head.", "The morning sunlight was shining brightly and it was warm. ", "This is definitely not an endorsement.", "A man is in a city.", "A girl is wearing blue.", "A man is a pilot of an airplane.", "A woman sits on a chair.", "It's worth being able to go at a pace you prefer.", "A woman is riding her bike.", "Critical factors for essential activities are set out.", "The child is wearing black.", "Steele did not keep her original story."]} +{"query": "It sets out essential activities in each phase along with critical factors related to those activities.", "pos": ["Critical factors for essential activities are set out."], "neg": ["A girl is wearing blue.", "Some women with flip-flops on, are walking along the beach", "Person on bike", "Financing is an issue for us in public schools.", "man at picnics cut steak", "The people are watching a funeral procession.", "The child is wearing black.", "It is only staged on Winter afternoons in Palma's large bullring.", "The rule discourages people to pay their child support.", "A woman sits on a chair.", "Several chefs are sitting down and talking about food.", "a dog is running", "The street was lined with white-painted houses.", "Steele did not keep her original story.", "The family was falling apart."]} +{"query": "A man giving a speech in a restaurant.", "pos": ["A person gives a speech."], "neg": ["The man sits at the table and eats food.", "A man is skiing down a mountain.", "The Spring Creek facility is old and outdated.", "I face a serious problem at eighteen years old. ", "The child is wearing black.", "No matter how old people get they never forget. 
", "Two children is sleeping.", "People watching a spaceship launch.", "The Commission notes that no significant alternatives were considered.", "A man in a vest sits in a car.", "A man is a pilot of an airplane.", "The people are watching a funeral procession.", "Steele did not keep her original story.", "A child is reading in her bedroom.", "Two people jumped off the dock."]} +{"query": "Indians having a gathering with coats and food and drinks.", "pos": ["A group of Indians are having a gathering with food and drinks"], "neg": ["The street was lined with white-painted houses.", "The 4 women are sitting on the beach.", "People watching a spaceship launch.", "The man sits at the table and eats food.", "This is definitely not an endorsement.", "An athlete is competing in the 1500 meter swimming competition.", "A man is in a city.", "It is calming to be assaulted.", "The girl is standing, leaning against the archway.", "The seal of Missouri is perfect.", "It is only staged on Winter afternoons in Palma's large bullring.", "She's not going to court to clear her record.", "Two men watching a magic show.", "A group of women watch soap operas.", "The family was falling apart."]} +{"query": "A woman with violet hair rides her bicycle outside.", "pos": ["A woman is riding her bike."], "neg": ["A group of women watch soap operas.", "a cat is running", "A child is reading in her bedroom.", "The Spring Creek facility is old and outdated.", "Two children is sleeping.", "The 4 women are sitting on the beach.", "It lays out critical activities but makes no provision for critical factors related to those activities.", "Several chefs are sitting down and talking about food.", "She's not going to court to clear her record.", "A group of people plays volleyball.", "People are assembled in protest.", "People watching a spaceship launch.", "The fatal dose was not taken when the murderer thought it would be.", "a fisherman is trying to catch a monkey", "The people are watching a funeral procession."]} +{"query": "A man pulls two women down a city street in a rickshaw.", "pos": ["A man is in a city."], "neg": ["We ran out of firewood and had to use pine needles for the fire.", "Nobody is jumping", "Steele did not keep her original story.", "The man is talking about hawaii.", "A group of Indians are having a gathering with food and drinks", "Person on bike", "A girl sits beside a boy.", "A girl is wearing blue.", "The morning sunlight was shining brightly and it was warm. ", "The Commission notes that no significant alternatives were considered.", "Mother Teresa is an easy choice.", "A group of Indians are having a funeral", "Right information can empower the legal service practices and the justice system. 
", "A woman is jogging in the park.", "A person gives a speech."]} diff --git a/pretrain.py b/pretrain.py new file mode 100755 index 00000000..1752f7e9 --- /dev/null +++ b/pretrain.py @@ -0,0 +1,22 @@ +import socket +import contextlib +import os +import sys +import torch +os.environ["WANDB_DISABLED"]="true" +args = (" ").join(sys.argv) +# 使用示例 +num_gpus = torch.cuda.device_count() +os.system("cd /opt/tiger/FlagEmbedding") +if not os.path.exists("/opt/tiger/train_15neg"): os.system("cp -r /mnt/bn/data-tns-live-llm/leon/experiments/llm/fcbank/train_15neg /opt/tiger/train_15neg") + +#——————————————————————————————————————————————————debug——————————————————————————————————————————————————————————# +# args = "--output_dir /mnt/bn/data-tns-live-llm/leon/experiments/llm/fcbank/toy --model_name_or_path /mnt/bn/data-tns-live-llm/leon/experiments/llm/fcbank/xlmr/models--FacebookAI--xlm-roberta-base/snapshots/e73636d4f797dec63c3081bb6ed5c7b0bb3f2089/ --train_data /opt/tiger/FlagEmbedding/examples/finetune/toy_finetune_data.jsonl --learning_rate 1e-5 --fp16 --num_train_epochs 5 --per_device_train_batch_size 2 --gradient_accumulation_steps 4 --dataloader_drop_last --train_group_size 10 --max_len 512 --weight_decay 0.01 --logging_steps 10 --save_strategy epoch --save_steps 1 --save_total_limit 3" +# print(args) +# num_gpus = 1 +#——————————————————————————————————————————————————debug——————————————————————————————————————————————————————————# + +# 构建训练命令 +command = f"""torchrun --rdzv_backend c10d --rdzv_endpoint localhost:0 --nproc_per_node {num_gpus} {args}""" +# 执行命令 +os.system(command) \ No newline at end of file diff --git a/tmp.py b/tmp.py new file mode 100644 index 00000000..16d07c33 --- /dev/null +++ b/tmp.py @@ -0,0 +1,5 @@ +import os +import sys +import torch +args = (" ").join(sys.argv) +print(args) \ No newline at end of file diff --git a/utils.py b/utils.py new file mode 100644 index 00000000..32a9cd2b --- /dev/null +++ b/utils.py @@ -0,0 +1,33 @@ +import os +import re + +PREFIX_CHECKPOINT_DIR = "checkpoint" +_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$") + + +def get_complete_last_checkpoint(folder): + """ + because the checkpoint saving may be killed by the process kill, we need to get the real last checkpoint, + check if the last checkpoint is has same file number with the second last one + """ + content = os.listdir(folder) + checkpoints = [ + path + for path in content + if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path)) + ] + if len(checkpoints) == 0: + return + sorted_checkpoints = sorted(checkpoints, key=lambda x: int(_re_checkpoint.search(x).group(1))) + last_checkpoint = os.path.join(folder, sorted_checkpoints[-1]) + if len(sorted_checkpoints) >= 2: + second_last_checkpoint = os.path.join(folder, sorted_checkpoints[-2]) + else: + second_last_checkpoint = last_checkpoint + # check if the two file have same file number + last_checkpoint_file = os.listdir(last_checkpoint) + second_last_checkpoint_file = os.listdir(second_last_checkpoint) + if len(last_checkpoint_file) == len(second_last_checkpoint_file): + return last_checkpoint + else: + return second_last_checkpoint