
Commit 9729aa0

Merge branch 'main' into v1.0-release

2 parents: f4d9f5e + f387387

File tree: 10 files changed, +287 −28 lines


README.md
Lines changed: 6 additions & 3 deletions
```diff
@@ -1,7 +1,7 @@
 <img src="https://raw.githubusercontent.com/huggingface/setfit/main/assets/setfit.png">
 
 <p align="center">
-🤗 <a href="https://huggingface.co/setfit" target="_blank">Models & Datasets</a> | 📕 <a href="https://huggingface.co/docs/setfit" target="_blank">Documentation</a> | 📖 <a href="https://huggingface.co/blog/setfit" target="_blank">Blog</a> | 📃 <a href="https://arxiv.org/abs/2209.11055" target="_blank">Paper</a>
+🤗 <a href="https://huggingface.co/models?library=setfit" target="_blank">Models</a> | 📊 <a href="https://huggingface.co/setfit" target="_blank">Datasets</a> | 📕 <a href="https://huggingface.co/docs/setfit" target="_blank">Documentation</a> | 📖 <a href="https://huggingface.co/blog/setfit" target="_blank">Blog</a> | 📃 <a href="https://arxiv.org/abs/2209.11055" target="_blank">Paper</a>
 </p>
 
 # SetFit - Efficient Few-shot Learning with Sentence Transformers
```
```diff
@@ -61,7 +61,10 @@ eval_dataset = dataset["validation"].select(range(100))
 test_dataset = dataset["validation"].select(range(100, len(dataset["validation"])))
 
 # Load a SetFit model from Hub
-model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
+model = SetFitModel.from_pretrained(
+    "sentence-transformers/paraphrase-mpnet-base-v2",
+    labels=["negative", "positive"],
+)
 
 args = TrainingArguments(
     batch_size=16,
```
````diff
@@ -94,7 +97,7 @@ model = SetFitModel.from_pretrained("tomaarsen/setfit-paraphrase-mpnet-base-v2-s
 # Run inference
 preds = model.predict(["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"])
 print(preds)
-# tensor([1, 0], dtype=torch.int32)
+# ["positive", "negative"]
 ```
 
 
````
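The README change above swaps integer class indices for the model's configured string labels at inference time. A minimal sketch of the new behavior, assuming an already fine-tuned checkpoint (the model id below is a placeholder, not from this commit):

```python
from setfit import SetFitModel

# Placeholder checkpoint id; substitute any SetFit model fine-tuned with
# string labels, e.g. loaded with labels=["negative", "positive"].
model = SetFitModel.from_pretrained("your-username/setfit-sst2-example")

# As of 1.0, `predict` maps class indices through `model.labels`, so the
# output reads like ["positive", "negative"] instead of
# tensor([1, 0], dtype=torch.int32).
preds = model.predict(["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"])
print(preds)
```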
setup.py
Lines changed: 2 additions & 1 deletion
```diff
@@ -16,6 +16,7 @@
     "evaluate>=0.3.0",
     "huggingface_hub>=0.13.0",
     "scikit-learn",
+    "packaging",
 ]
 ABSA_REQUIRE = ["spacy"]
 QUALITY_REQUIRE = ["black", "flake8", "isort", "tabulate"]
```

```diff
@@ -53,7 +54,7 @@ def combine_requirements(base_keys):
 
 setup(
     name="setfit",
-    version="1.0.1",
+    version="1.0.2",
     description="Efficient few-shot learning with Sentence Transformers",
     long_description=README_TEXT,
     long_description_content_type="text/markdown",
```

src/setfit/__init__.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -1,4 +1,4 @@
-__version__ = "1.0.1"
+__version__ = "1.0.2"
 
 import importlib
 import os
```

src/setfit/modeling.py
Lines changed: 51 additions & 6 deletions
```diff
@@ -19,7 +19,10 @@
 import torch
 from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
 from huggingface_hub.utils import validate_hf_hub_args
-from sentence_transformers import SentenceTransformer, models
+from packaging.version import Version, parse
+from sentence_transformers import SentenceTransformer
+from sentence_transformers import __version__ as sentence_transformers_version
+from sentence_transformers import models
 from sklearn.linear_model import LogisticRegression
 from sklearn.multiclass import OneVsRestClassifier
 from sklearn.multioutput import ClassifierChain, MultiOutputClassifier
```
```diff
@@ -215,6 +218,7 @@ class SetFitModel(PyTorchModelHubMixin):
     normalize_embeddings: bool = False
     labels: Optional[List[str]] = None
     model_card_data: Optional[SetFitModelCardData] = field(default_factory=SetFitModelCardData)
+    sentence_transformers_kwargs: Dict = field(default_factory=dict, repr=False)
 
     attributes_to_save: Set[str] = field(
         init=False, repr=False, default_factory=lambda: {"normalize_embeddings", "labels"}
```
```diff
@@ -501,6 +505,11 @@ def predict_proba(
             inputs = [inputs]
         embeddings = self.encode(inputs, batch_size=batch_size, show_progress_bar=show_progress_bar)
         probs = self.model_head.predict_proba(embeddings)
+        if isinstance(probs, list):
+            if self.has_differentiable_head:
+                probs = torch.stack(probs, axis=1)
+            else:
+                probs = np.stack(probs, axis=1)
         outputs = self._output_type_conversion(probs, as_numpy=as_numpy)
         return outputs[0] if is_singular else outputs
 
```
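The new `predict_proba` branch handles multi-output heads, where scikit-learn returns a list of per-output probability arrays rather than a single matrix. A self-contained sketch of that shape mismatch and the `np.stack(probs, axis=1)` fix, using a `MultiOutputClassifier` as a stand-in for the model head:

```python
import numpy as np
from sklearn.datasets import make_multilabel_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier

X, Y = make_multilabel_classification(n_samples=64, n_classes=3, random_state=0)
head = MultiOutputClassifier(LogisticRegression()).fit(X, Y)

probs = head.predict_proba(X[:4])
# scikit-learn returns one (n_samples, 2) array per output here...
print(type(probs), len(probs), probs[0].shape)  # <class 'list'> 3 (4, 2)

# ...so stacking along axis=1 yields one (n_samples, n_outputs, 2) array,
# mirroring the np.stack(probs, axis=1) branch in the diff above.
stacked = np.stack(probs, axis=1)
print(stacked.shape)  # (4, 3, 2)
```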
```diff
@@ -600,6 +609,9 @@ def device(self) -> torch.device:
         Returns:
             torch.device: The device that the model is on.
         """
+        # SentenceTransformers.device is reliable from 2.3.0 onwards
+        if parse(sentence_transformers_version) >= Version("2.3.0"):
+            return self.model_body.device
         return self.model_body._target_device
 
     def to(self, device: Union[str, torch.device]) -> "SetFitModel":
```
```diff
@@ -617,9 +629,10 @@ def to(self, device: Union[str, torch.device]) -> "SetFitModel":
         Returns:
             SetFitModel: Returns the original model, but now on the desired device.
         """
-        # Note that we must also set _target_device, or any SentenceTransformer.fit() call will reset
-        # the body location
-        self.model_body._target_device = device if isinstance(device, torch.device) else torch.device(device)
+        # Note that we must also set _target_device with sentence-transformers <2.3.0,
+        # or any SentenceTransformer.fit() call will reset the body location
+        if parse(sentence_transformers_version) < Version("2.3.0"):
+            self.model_body._target_device = device if isinstance(device, torch.device) else torch.device(device)
         self.model_body = self.model_body.to(device)
 
         if self.has_differentiable_head:
```
```diff
@@ -696,10 +709,37 @@ def _from_pretrained(
         multi_target_strategy: Optional[str] = None,
         use_differentiable_head: bool = False,
         device: Optional[Union[torch.device, str]] = None,
+        trust_remote_code: bool = False,
         **model_kwargs,
     ) -> "SetFitModel":
-        model_body = SentenceTransformer(model_id, cache_folder=cache_dir, use_auth_token=token, device=device)
-        device = model_body._target_device
+        sentence_transformers_kwargs = {
+            "cache_folder": cache_dir,
+            "use_auth_token": token,
+            "device": device,
+            "trust_remote_code": trust_remote_code,
+        }
+        if parse(sentence_transformers_version) >= Version("2.3.0"):
+            sentence_transformers_kwargs = {
+                "cache_folder": cache_dir,
+                "token": token,
+                "device": device,
+                "trust_remote_code": trust_remote_code,
+            }
+        else:
+            if trust_remote_code:
+                raise ValueError(
+                    "The `trust_remote_code` argument is only supported for `sentence-transformers` >= 2.3.0."
+                )
+            sentence_transformers_kwargs = {
+                "cache_folder": cache_dir,
+                "use_auth_token": token,
+                "device": device,
+            }
+        model_body = SentenceTransformer(model_id, **sentence_transformers_kwargs)
+        if parse(sentence_transformers_version) >= Version("2.3.0"):
+            device = model_body.device
+        else:
+            device = model_body._target_device
         model_body.to(device)  # put `model_body` on the target device
 
         # Try to load a SetFit config file
```
```diff
@@ -822,6 +862,7 @@ def _from_pretrained(
             model_head=model_head,
             multi_target_strategy=multi_target_strategy,
             model_card_data=model_card_data,
+            sentence_transformers_kwargs=sentence_transformers_kwargs,
             **model_kwargs,
         )
 
```
```diff
@@ -846,6 +887,10 @@ def _from_pretrained(
             Whether to apply normalization on the embeddings produced by the Sentence Transformer body.
         device (`Union[torch.device, str]`, *optional*):
             The device on which to load the SetFit model, e.g. `"cuda:0"`, `"mps"` or `torch.device("cuda")`.
+        trust_remote_code (`bool`, defaults to `False`): Whether or not to allow for custom Sentence Transformers
+            models defined on the Hub in their own modeling files. This option should only be set to True for
+            repositories you trust and in which you have read the code, as it will execute code present on
+            the Hub on your local machine. Defaults to False.
 
     Example::
 
```
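The `trust_remote_code` flag added above is forwarded into the `SentenceTransformer` body when loading. A hedged usage sketch; the model id is a placeholder for a repository whose remote modeling code you have actually audited:

```python
from setfit import SetFitModel

# Requires sentence-transformers >= 2.3.0; on older versions the loader
# raises the ValueError shown in the diff above. Only enable this for
# repositories you trust and whose code you have read.
model = SetFitModel.from_pretrained(
    "your-username/setfit-with-custom-st-body",  # placeholder id
    trust_remote_code=True,
)
```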
src/setfit/span/modeling.py
Lines changed: 96 additions & 1 deletion
```diff
@@ -1,12 +1,15 @@
 import copy
 import os
+import re
 import tempfile
 import types
+from collections import defaultdict
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
 
 import torch
+from datasets import Dataset
 from huggingface_hub.utils import SoftTemporaryDirectory
 
 from setfit.utils import set_docstring
```
```diff
@@ -148,7 +151,99 @@ class AbsaModel:
     aspect_model: AspectModel
     polarity_model: PolarityModel
 
-    def predict(self, inputs: Union[str, List[str]]) -> List[Dict[str, Any]]:
+    def gold_aspect_spans_to_aspects_list(self, inputs: Dataset) -> List[List[slice]]:
+        # First group inputs by text
+        grouped_data = defaultdict(list)
+        for sample in inputs:
+            text = sample.pop("text")
+            grouped_data[text].append(sample)
+
+        # Get the spaCy docs
+        docs, _ = self.aspect_extractor(grouped_data.keys())
+
+        # Get the aspect spans for each doc by matching gold spans to the spaCy tokens
+        aspects_list = []
+        index = -1
+        skipped_indices = []
+        for doc, samples in zip(docs, grouped_data.values()):
+            aspects_list.append([])
+            for sample in samples:
+                index += 1
+                match_objects = re.finditer(re.escape(sample["span"]), doc.text)
+                for i, match in enumerate(match_objects):
+                    if i == sample["ordinal"]:
+                        char_idx_start = match.start()
+                        char_idx_end = match.end()
+                        span = doc.char_span(char_idx_start, char_idx_end)
+                        if span is None:
+                            logger.warning(
+                                f"Aspect term {sample['span']!r} with ordinal {sample['ordinal']}, isn't a token in {doc.text!r} according to spaCy. "
+                                "Skipping this sample."
+                            )
+                            skipped_indices.append(index)
+                            continue
+                        aspects_list[-1].append(slice(span.start, span.end))
+        return docs, aspects_list, skipped_indices
+
+    def predict_dataset(self, inputs: Dataset) -> Dataset:
+        if set(inputs.column_names) >= {"text", "span", "ordinal"}:
+            pass
+        elif set(inputs.column_names) >= {"text", "span"}:
+            inputs = inputs.add_column("ordinal", [0] * len(inputs))
+        else:
+            raise ValueError(
+                "`inputs` must be either a `str`, a `List[str]`, or a `datasets.Dataset` with columns `text` and `span` and optionally `ordinal`. "
+                f"Found a dataset with these columns: {inputs.column_names}."
+            )
+        if "pred_polarity" in inputs.column_names:
+            raise ValueError(
+                "`predict_dataset` wants to add a `pred_polarity` column, but the input dataset already contains that column."
+            )
+        docs, aspects_list, skipped_indices = self.gold_aspect_spans_to_aspects_list(inputs)
+        polarity_list = sum(self.polarity_model(docs, aspects_list), [])
+        for index in skipped_indices:
+            polarity_list.insert(index, None)
+        return inputs.add_column("pred_polarity", polarity_list)
+
+    def predict(self, inputs: Union[str, List[str], Dataset]) -> Union[List[Dict[str, Any]], Dataset]:
+        """Predicts aspects & their polarities of the given inputs.
+
+        Example::
+
+            >>> from setfit import AbsaModel
+            >>> model = AbsaModel.from_pretrained(
+            ...     "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect",
+            ...     "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity",
+            ... )
+            >>> model.predict("The food and wine are just exquisite.")
+            [{'span': 'food', 'polarity': 'positive'}, {'span': 'wine', 'polarity': 'positive'}]
+
+            >>> from setfit import AbsaModel
+            >>> from datasets import load_dataset
+            >>> model = AbsaModel.from_pretrained(
+            ...     "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect",
+            ...     "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity",
+            ... )
+            >>> dataset = load_dataset("tomaarsen/setfit-absa-semeval-restaurants", split="train")
+            >>> model.predict(dataset)
+            Dataset({
+                features: ['text', 'span', 'label', 'ordinal', 'pred_polarity'],
+                num_rows: 3693
+            })
+
+        Args:
+            inputs (Union[str, List[str], Dataset]): Either a sentence, a list of sentences,
+                or a dataset with columns `text` and `span` and optionally `ordinal`. This dataset
+                contains gold aspects, and we only predict the polarities for them.
+
+        Returns:
+            Union[List[Dict[str, Any]], Dataset]: Either a list of dictionaries with keys `span`
+                and `polarity` if the input was a sentence or a list of sentences, or a dataset with
+                columns `text`, `span`, `ordinal`, and `pred_polarity`.
+        """
+        if isinstance(inputs, Dataset):
+            return self.predict_dataset(inputs)
+
         is_str = isinstance(inputs, str)
         inputs_list = [inputs] if is_str else inputs
         docs, aspects_list = self.aspect_extractor(inputs_list)
```
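In `gold_aspect_spans_to_aspects_list`, the `ordinal` column disambiguates aspect terms that occur several times in one sentence: the code takes the `ordinal`-th regex match and converts it into a spaCy token span. A self-contained sketch of just the character-level matching idea, with a hypothetical helper name and no spaCy dependency:

```python
import re

def nth_occurrence_char_span(text: str, term: str, ordinal: int):
    """Return the (start, end) character span of the `ordinal`-th occurrence
    of `term` in `text`, or None if there are fewer occurrences."""
    for i, match in enumerate(re.finditer(re.escape(term), text)):
        if i == ordinal:
            return match.start(), match.end()
    return None

text = "The wine list is long, and the wine itself is great."
print(nth_occurrence_char_span(text, "wine", 0))  # (4, 8)   -> first "wine"
print(nth_occurrence_char_span(text, "wine", 1))  # (31, 35) -> second "wine"
```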

src/setfit/trainer.py
Lines changed: 11 additions & 11 deletions
```diff
@@ -507,10 +507,8 @@ def get_dataloader(
             args.sampling_strategy,
             max_pairs=max_pairs,
         )
-        # shuffle_sampler = True can be dropped in for further 'randomising'
-        shuffle_sampler = True if args.sampling_strategy == "unique" else False
         batch_size = min(args.embedding_batch_size, len(data_sampler))
-        dataloader = DataLoader(data_sampler, batch_size=batch_size, shuffle=shuffle_sampler, drop_last=False)
+        dataloader = DataLoader(data_sampler, batch_size=batch_size, drop_last=False)
         loss = args.loss(self.model.model_body)
 
         return dataloader, loss, batch_size
```

```diff
@@ -576,8 +574,8 @@ def _train_sentence_transformer(
         if args.use_amp:
             scaler = torch.cuda.amp.GradScaler()
 
-        model_body.to(model_body._target_device)
-        loss_func.to(model_body._target_device)
+        model_body.to(self.model.device)
+        loss_func.to(self.model.device)
 
         # Use smart batching
         train_dataloader.collate_fn = model_body.smart_batching_collate
```

```diff
@@ -625,8 +623,8 @@
                 data = next(data_iterator)
 
                 features, labels = data
-                labels = labels.to(model_body._target_device)
-                features = list(map(lambda batch: batch_to_device(batch, model_body._target_device), features))
+                labels = labels.to(self.model.device)
+                features = list(map(lambda batch: batch_to_device(batch, self.model.device), features))
 
                 if args.use_amp:
                     with autocast():
```

```diff
@@ -673,10 +671,12 @@
             step_to_load = dir_name[5:]
             logger.info(f"Loading best SentenceTransformer model from step {step_to_load}.")
             self.model.model_card_data.set_best_model_step(int(step_to_load))
+            sentence_transformer_kwargs = self.model.sentence_transformers_kwargs
+            sentence_transformer_kwargs["device"] = self.model.device
             self.model.model_body = SentenceTransformer(
-                self.state.best_model_checkpoint, device=model_body._target_device
+                self.state.best_model_checkpoint, **sentence_transformer_kwargs
             )
-            self.model.model_body.to(model_body._target_device)
+            self.model.model_body.to(self.model.device)
 
         # Ensure logging the speed metrics
         num_train_samples = self.state.max_steps * args.embedding_batch_size  # * args.gradient_accumulation_steps
```

```diff
@@ -736,8 +736,8 @@ def _evaluate_with_loss(
             tqdm(iter(eval_dataloader), total=eval_steps, leave=False, disable=not args.show_progress_bar), start=1
         ):
             features, labels = data
-            labels = labels.to(model_body._target_device)
-            features = list(map(lambda batch: batch_to_device(batch, model_body._target_device), features))
+            labels = labels.to(self.model.device)
+            features = list(map(lambda batch: batch_to_device(batch, self.model.device), features))
 
             if args.use_amp:
                 with autocast():
```
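All of the `model_body._target_device` call sites above now route through the version-gated `SetFitModel.device` property from the modeling.py diff. A small sketch of the contract the trainer relies on, assuming only that the base model can be downloaded:

```python
import torch
from setfit import SetFitModel

model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

# On sentence-transformers >= 2.3.0 this resolves to model_body.device;
# on older versions it falls back to the private _target_device.
print(model.device)

# `to()` keeps both code paths consistent, so the trainer can move batches
# with batch_to_device(batch, self.model.device) on either version.
model.to("cpu")
assert model.device == torch.device("cpu")
```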

src/setfit/trainer_distillation.py
Lines changed: 1 addition & 3 deletions

```diff
@@ -93,10 +93,8 @@ def get_dataloader(
         data_sampler = ContrastiveDistillationDataset(
             input_data, cos_sim_matrix, args.num_iterations, args.sampling_strategy, max_pairs=max_pairs
         )
-        # shuffle_sampler = True can be dropped in for further 'randomising'
-        shuffle_sampler = True if args.sampling_strategy == "unique" else False
         batch_size = min(args.embedding_batch_size, len(data_sampler))
-        dataloader = DataLoader(data_sampler, batch_size=batch_size, shuffle=shuffle_sampler, drop_last=False)
+        dataloader = DataLoader(data_sampler, batch_size=batch_size, drop_last=False)
         loss = args.loss(self.model.model_body)
         return dataloader, loss, batch_size
 
```
tests/conftest.py
Lines changed: 8 additions & 0 deletions

```diff
@@ -14,6 +14,14 @@ def absa_model() -> AbsaModel:
     return AbsaModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", spacy_model="en_core_web_sm")
 
 
+@pytest.fixture()
+def trained_absa_model() -> AbsaModel:
+    return AbsaModel.from_pretrained(
+        "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect",
+        "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity",
+    )
+
+
 @pytest.fixture()
 def absa_dataset() -> Dataset:
     texts = [
```
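A hypothetical test that could consume the new `trained_absa_model` fixture; the expected output mirrors the `predict` docstring added in this commit, but this particular test is illustrative and not part of the diff:

```python
# Hypothetical usage of the trained_absa_model fixture; the actual tests
# added in this commit are not shown in the diff above.
def test_trained_absa_model_predict(trained_absa_model):
    preds = trained_absa_model.predict("The food and wine are just exquisite.")
    assert preds == [
        {"span": "food", "polarity": "positive"},
        {"span": "wine", "polarity": "positive"},
    ]
```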
