-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
+import torch
 from haystack.components.embedders import (
     SentenceTransformersDocumentEmbedder,
     SentenceTransformersTextEmbedder,
@@ -33,7 +34,7 @@ def __init__(
         import sentence_transformers
 
         class _IPEXSTTransformers(sentence_transformers.models.Transformer):
-            def _load_model(self, model_name_or_path, config, cache_dir, **model_args):
+            def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args):
                 print("Loading IPEX ST Transformer model")
                 optimized_intel_import.check()
                 self.auto_model = IPEXModel.from_pretrained(
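
The extra `backend` parameter appears to track the `sentence_transformers.models.Transformer._load_model` signature from sentence-transformers 3.x, which threads the selected backend ("torch", "onnx", or "openvino") through to model loading. A minimal sketch of a version-tolerant override, assuming the pinned sentence-transformers version is unknown (the default value is the assumption here):

def _load_model(self, model_name_or_path, config, cache_dir, backend="torch", **model_args):
    # Accept 'backend' for compatibility with sentence-transformers >= 3.x;
    # older releases simply never pass it. IPEXModel always takes the
    # PyTorch path, so the value is ignored.
    ...
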
@@ -89,23 +90,39 @@ def _load_auto_model(
                 cache_folder: Optional[str],
                 revision: Optional[str] = None,
                 trust_remote_code: bool = False,
+                local_files_only: bool = False,
+                model_kwargs: Optional[Dict[str, Any]] = None,
+                tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+                config_kwargs: Optional[Dict[str, Any]] = None,
             ):
                 """
                 Creates a simple Transformer + Mean Pooling model and returns the modules
                 """
+
+                shared_kwargs = {
+                    "token": token,
+                    "trust_remote_code": trust_remote_code,
+                    "revision": revision,
+                    "local_files_only": local_files_only,
+                }
+                model_kwargs = (
+                    shared_kwargs if model_kwargs is None else {**shared_kwargs, **model_kwargs}
+                )
+                tokenizer_kwargs = (
+                    shared_kwargs
+                    if tokenizer_kwargs is None
+                    else {**shared_kwargs, **tokenizer_kwargs}
+                )
+                config_kwargs = (
+                    shared_kwargs if config_kwargs is None else {**shared_kwargs, **config_kwargs}
+                )
+
                 transformer_model = _IPEXSTTransformers(
                     model_name_or_path,
                     cache_dir=cache_folder,
-                    model_args={
-                        "token": token,
-                        "trust_remote_code": trust_remote_code,
-                        "revision": revision,
-                    },
-                    tokenizer_args={
-                        "token": token,
-                        "trust_remote_code": trust_remote_code,
-                        "revision": revision,
-                    },
+                    model_args=model_kwargs,
+                    tokenizer_args=tokenizer_kwargs,
+                    config_args=config_kwargs,
                 )
                 pooling_model = sentence_transformers.models.Pooling(
                     transformer_model.get_word_embedding_dimension(), "mean"
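
The refactor above folds the shared auth/revision settings into `shared_kwargs` and lets callers override any of them, since later entries win in a dict-unpacking merge. A standalone sketch of that merge rule:

# Caller-supplied kwargs override the shared defaults because they are
# unpacked last in {**shared_kwargs, **model_kwargs}.
shared_kwargs = {"token": None, "revision": "main", "local_files_only": False}
model_kwargs = {"revision": "v1.2"}  # hypothetical caller override
merged = {**shared_kwargs, **model_kwargs}
assert merged["revision"] == "v1.2"           # override wins
assert merged["local_files_only"] is False    # untouched default survives

One side effect worth noting: when all three arguments are None, `model_kwargs`, `tokenizer_kwargs`, and `config_kwargs` all alias the same `shared_kwargs` dict, which is harmless only as long as downstream code does not mutate them in place.
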
@@ -114,7 +131,7 @@ def _load_auto_model(
 
             @property
             def device(self):
-                return "cpu"
+                return torch.device("cpu")
 
         self.model = _IPEXSentenceTransformer(
             model_name_or_path=model,
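
Returning `torch.device("cpu")` instead of a bare string plausibly keeps the property compatible with callers that inspect the device object rather than only passing it to `.to()`. A quick illustration:

import torch

dev = torch.device("cpu")
print(dev.type)            # 'cpu' -- torch.device exposes .type
t = torch.ones(2).to(dev)  # .to() accepts a string too, so this part was never the issue
# "cpu".type               # a plain string has no .type attribute
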
@@ -132,7 +149,7 @@ def ipex_model_warm_up(self):
     """
     Initializes the component.
     """
-    if not hasattr(self, "embedding_backend"):
+    if not getattr(self, "embedding_backend", None):
         self.embedding_backend = _IPEXSentenceTransformersEmbeddingBackend(
             model=self.model,
             device=self.device.to_torch_str(),
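
The `getattr` check fixes a subtle warm-up bug: `hasattr` returns True as soon as the attribute exists, even when its value is None, so a backend initialized to None as a placeholder (an assumption about the embedder's `__init__` here) would never be replaced with a real one. A minimal repro:

class Demo:
    def __init__(self):
        self.embedding_backend = None  # placeholder, as assumed for the embedder's __init__

d = Demo()
print(hasattr(d, "embedding_backend"))                 # True  -- old check skips warm-up
print(getattr(d, "embedding_backend", None) is None)   # True  -- new check runs warm-up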