 import torch
 from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
 from huggingface_hub.utils import validate_hf_hub_args
-from sentence_transformers import SentenceTransformer, models
+from packaging.version import Version, parse
+from sentence_transformers import SentenceTransformer
+from sentence_transformers import __version__ as sentence_transformers_version
+from sentence_transformers import models
 from sklearn.linear_model import LogisticRegression
 from sklearn.multiclass import OneVsRestClassifier
 from sklearn.multioutput import ClassifierChain, MultiOutputClassifier
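The `packaging` imports above exist so the library can gate behavior on the installed sentence-transformers release at runtime. A minimal standalone sketch of the PEP 440 comparison semantics this commit relies on (not part of the diff); note that dev/pre-releases compare below their final release:

```python
# Sketch of the version comparisons used throughout this commit.
from packaging.version import Version, parse

print(parse("2.2.2") >= Version("2.3.0"))       # False -> legacy code path
print(parse("2.3.1") >= Version("2.3.0"))       # True  -> new code path
print(parse("2.3.0.dev0") >= Version("2.3.0"))  # False -> pre-releases sort below 2.3.0
```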
@@ -215,6 +218,7 @@ class SetFitModel(PyTorchModelHubMixin):
     normalize_embeddings: bool = False
     labels: Optional[List[str]] = None
     model_card_data: Optional[SetFitModelCardData] = field(default_factory=SetFitModelCardData)
+    sentence_transformers_kwargs: Dict = field(default_factory=dict, repr=False)
 
     attributes_to_save: Set[str] = field(
         init=False, repr=False, default_factory=lambda: {"normalize_embeddings", "labels"}
@@ -501,6 +505,11 @@ def predict_proba(
             inputs = [inputs]
         embeddings = self.encode(inputs, batch_size=batch_size, show_progress_bar=show_progress_bar)
         probs = self.model_head.predict_proba(embeddings)
+        if isinstance(probs, list):
+            if self.has_differentiable_head:
+                probs = torch.stack(probs, axis=1)
+            else:
+                probs = np.stack(probs, axis=1)
         outputs = self._output_type_conversion(probs, as_numpy=as_numpy)
         return outputs[0] if is_singular else outputs
 
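For context on the new stacking branch: with a multi-output head, scikit-learn's `predict_proba` returns a list of per-output arrays rather than a single array, so the list is stacked into one `(n_samples, n_outputs, n_classes)` array before type conversion. A self-contained sketch (the toy data is illustrative only):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier

rng = np.random.default_rng(0)
X = rng.random((8, 4))  # 8 toy "embeddings" of dimension 4
y = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 1], [1, 0, 0],
              [0, 0, 1], [1, 1, 0], [0, 1, 0], [1, 0, 1]])  # 3 binary targets
head = MultiOutputClassifier(LogisticRegression()).fit(X, y)

probs = head.predict_proba(X)               # list of 3 arrays, each of shape (8, 2)
stacked = np.stack(probs, axis=1)           # one array of shape (8, 3, 2)
print(type(probs).__name__, stacked.shape)  # list (8, 3, 2)
```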
@@ -600,6 +609,9 @@ def device(self) -> torch.device:
         Returns:
             torch.device: The device that the model is on.
         """
+        # SentenceTransformers.device is reliable from 2.3.0 onwards
+        if parse(sentence_transformers_version) >= Version("2.3.0"):
+            return self.model_body.device
         return self.model_body._target_device
 
     def to(self, device: Union[str, torch.device]) -> "SetFitModel":
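The same fallback is useful outside `SetFitModel` when code must support both sentence-transformers generations. A hedged standalone sketch of the property above (`resolve_device` is a name invented here, not part of the commit):

```python
import torch
from packaging.version import Version, parse
from sentence_transformers import SentenceTransformer
from sentence_transformers import __version__ as sentence_transformers_version

def resolve_device(model_body: SentenceTransformer) -> torch.device:
    # The public `.device` property is only reliable from sentence-transformers 2.3.0.
    if parse(sentence_transformers_version) >= Version("2.3.0"):
        return model_body.device
    # Older releases track the intended placement in a private attribute.
    return model_body._target_device
```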
@@ -617,9 +629,10 @@ def to(self, device: Union[str, torch.device]) -> "SetFitModel":
         Returns:
             SetFitModel: Returns the original model, but now on the desired device.
         """
-        # Note that we must also set _target_device, or any SentenceTransformer.fit() call will reset
-        # the body location
-        self.model_body._target_device = device if isinstance(device, torch.device) else torch.device(device)
+        # Note that we must also set _target_device with sentence-transformers <2.3.0,
+        # or any SentenceTransformer.fit() call will reset the body location
+        if parse(sentence_transformers_version) < Version("2.3.0"):
+            self.model_body._target_device = device if isinstance(device, torch.device) else torch.device(device)
         self.model_body = self.model_body.to(device)
 
         if self.has_differentiable_head:
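Since `to()` returns the model itself (per the docstring above), placement can be chained directly onto loading. A usage sketch; the model id is just a stand-in for any Sentence Transformer body:

```python
from setfit import SetFitModel

# "BAAI/bge-small-en-v1.5" is a stand-in; any Sentence Transformer body works.
model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5").to("cuda:0")
print(model.device)  # resolved via the version-gated property above
```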
@@ -696,10 +709,37 @@ def _from_pretrained(
         multi_target_strategy: Optional[str] = None,
         use_differentiable_head: bool = False,
         device: Optional[Union[torch.device, str]] = None,
+        trust_remote_code: bool = False,
         **model_kwargs,
     ) -> "SetFitModel":
-        model_body = SentenceTransformer(model_id, cache_folder=cache_dir, use_auth_token=token, device=device)
-        device = model_body._target_device
+        sentence_transformers_kwargs = {
+            "cache_folder": cache_dir,
+            "use_auth_token": token,
+            "device": device,
+            "trust_remote_code": trust_remote_code,
+        }
+        if parse(sentence_transformers_version) >= Version("2.3.0"):
+            sentence_transformers_kwargs = {
+                "cache_folder": cache_dir,
+                "token": token,
+                "device": device,
+                "trust_remote_code": trust_remote_code,
+            }
+        else:
+            if trust_remote_code:
+                raise ValueError(
+                    "The `trust_remote_code` argument is only supported for `sentence-transformers` >= 2.3.0."
+                )
+            sentence_transformers_kwargs = {
+                "cache_folder": cache_dir,
+                "use_auth_token": token,
+                "device": device,
+            }
+        model_body = SentenceTransformer(model_id, **sentence_transformers_kwargs)
+        if parse(sentence_transformers_version) >= Version("2.3.0"):
+            device = model_body.device
+        else:
+            device = model_body._target_device
         model_body.to(device)  # put `model_body` on the target device
 
         # Try to load a SetFit config file
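Downstream, the new `trust_remote_code` flag flows from `SetFitModel.from_pretrained` into the `SentenceTransformer` constructor. A usage sketch (the repository id is hypothetical; this requires sentence-transformers >= 2.3.0, per the `ValueError` above):

```python
from setfit import SetFitModel

model = SetFitModel.from_pretrained(
    "some-org/custom-sentence-transformer",  # hypothetical repo with its own modeling files
    trust_remote_code=True,  # runs repo code locally; only enable for repos you trust
)
```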
@@ -822,6 +862,7 @@ def _from_pretrained(
             model_head=model_head,
             multi_target_strategy=multi_target_strategy,
             model_card_data=model_card_data,
+            sentence_transformers_kwargs=sentence_transformers_kwargs,
             **model_kwargs,
         )
 
@@ -846,6 +887,10 @@ def _from_pretrained(
                 Whether to apply normalization on the embeddings produced by the Sentence Transformer body.
             device (`Union[torch.device, str]`, *optional*):
                 The device on which to load the SetFit model, e.g. `"cuda:0"`, `"mps"` or `torch.device("cuda")`.
+            trust_remote_code (`bool`, defaults to `False`): Whether or not to allow for custom Sentence Transformers
+                models defined on the Hub in their own modeling files. This option should only be set to `True` for
+                repositories you trust and in which you have read the code, as it will execute code present on
+                the Hub on your local machine.
 
         Example::
 