From fa76bcc65f2d30a6ec2abcaa1393861480c9b685 Mon Sep 17 00:00:00 2001 From: ayang Date: Sat, 25 Jan 2025 11:40:25 +0800 Subject: [PATCH] pass kwargs into AutoModel.from_pretrained MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 比如local_files_only必须传进去,要不断网后没法使用。直接把所有kwargs传进去是没问题的,因为 transformers/models/auto/auto_factory.py from_pretrained 函数对参数进行了过滤。 --- FlagEmbedding/inference/embedder/encoder_only/base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/FlagEmbedding/inference/embedder/encoder_only/base.py b/FlagEmbedding/inference/embedder/encoder_only/base.py index 2547ab24..0c759bca 100644 --- a/FlagEmbedding/inference/embedder/encoder_only/base.py +++ b/FlagEmbedding/inference/embedder/encoder_only/base.py @@ -74,12 +74,14 @@ def __init__( self.tokenizer = AutoTokenizer.from_pretrained( model_name_or_path, trust_remote_code=trust_remote_code, - cache_dir=cache_dir + cache_dir=cache_dir, + **kwargs ) self.model = AutoModel.from_pretrained( model_name_or_path, trust_remote_code=trust_remote_code, - cache_dir=cache_dir + cache_dir=cache_dir, + **kwargs ) def encode_queries(