diff --git a/FlagEmbedding/abc/finetune/embedder/AbsTrainer.py b/FlagEmbedding/abc/finetune/embedder/AbsTrainer.py
index 38944ed8..73afeb35 100644
--- a/FlagEmbedding/abc/finetune/embedder/AbsTrainer.py
+++ b/FlagEmbedding/abc/finetune/embedder/AbsTrainer.py
@@ -2,6 +2,7 @@
 from typing import Optional
 from abc import ABC, abstractmethod
 
 from transformers.trainer import Trainer
+from sentence_transformers import SentenceTransformer, models
 
 logger = logging.getLogger(__name__)
@@ -35,3 +36,16 @@ def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
         loss = outputs.loss
 
         return (loss, outputs) if return_outputs else loss
+
+    @staticmethod
+    def save_ckpt_for_sentence_transformers(ckpt_dir, pooling_mode: str = 'cls', normalized: bool = True):
+        """Re-save the Hugging Face checkpoint in ckpt_dir as a sentence-transformers model."""
+        word_embedding_model = models.Transformer(ckpt_dir)
+        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode=pooling_mode)
+        if normalized:
+            # Append an L2-normalization layer so encoded embeddings are unit-length.
+            normalize_layer = models.Normalize()
+            model = SentenceTransformer(modules=[word_embedding_model, pooling_model, normalize_layer], device='cpu')
+        else:
+            model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='cpu')
+        model.save(ckpt_dir)
diff --git a/FlagEmbedding/finetune/embedder/encoder_only/base/trainer.py b/FlagEmbedding/finetune/embedder/encoder_only/base/trainer.py
index affa34a8..6e768d66 100644
--- a/FlagEmbedding/finetune/embedder/encoder_only/base/trainer.py
+++ b/FlagEmbedding/finetune/embedder/encoder_only/base/trainer.py
@@ -1,5 +1,4 @@
 import os
-import torch
 import logging
 from typing import Optional
 
@@ -32,13 +31,9 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
                                 f'does not support save interface')
         else:
             self.model.save(output_dir)
-        if self.tokenizer is not None and self.is_world_process_zero():
-            self.tokenizer.save_pretrained(output_dir)
-
-        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
-
-        # save the checkpoint for sentence-transformers library
-        # if self.is_world_process_zero():
-        #     save_ckpt_for_sentence_transformers(output_dir,
-        #                                         pooling_mode=self.args.sentence_pooling_method,
-        #                                         normlized=self.args.normlized)
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(output_dir)
+        if self.is_world_process_zero():
+            # Convert after the tokenizer is written so models.Transformer can load it from output_dir.
+            self.save_ckpt_for_sentence_transformers(output_dir,
+                                                     pooling_mode=self.args.sentence_pooling_method)
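For reference, a minimal sketch of how the converted checkpoint can be consumed once `_save` has run; `output_dir` is a hypothetical path standing in for the trainer's save directory, and the query text is made up:

```python
from sentence_transformers import SentenceTransformer

# "output_dir" is a placeholder: the directory _save() wrote to above.
# After conversion it contains a modules.json describing the
# Transformer + Pooling (+ Normalize) stack, so it loads directly
# as a sentence-transformers model.
model = SentenceTransformer("output_dir", device="cpu")
embeddings = model.encode(["how do I fine-tune an embedding model?"])
print(embeddings.shape)  # (1, hidden_size); unit-length if Normalize was appended
```

Loading works because `SentenceTransformer.save` writes a `modules.json` alongside the weights, so the module stack is reassembled automatically without any FlagEmbedding-specific code.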