
Commit 76e3a32

Feat/zshot version (#17)

* 🎨 Improved structure of setup and init.
* ✏️ Fixed minor typos and formatting in the evaluator.
* ✅ Updated the evaluation tests to work with the latest version of evaluate.
* 🐛 Fixed a bug when importing the version.

1 parent 92cfc30 commit 76e3a32

File tree: 6 files changed, +20 -13 lines

setup.cfg (+3)

@@ -1,2 +1,5 @@
 [egg_info]
 tag_svn_revision = true
+
+[metadata]
+version = attr: zshot.__version__
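
The attr: directive is setuptools' declarative way to single-source the version: at build time, setuptools imports (or statically parses) zshot/__init__.py and reads its __version__ attribute. A minimal sketch of inspecting the resolved value, assuming setuptools >= 61 and that it runs from the repository root:

# Sketch: confirm setuptools resolves "version = attr: zshot.__version__".
# Assumes setuptools >= 61 and the current directory is the repository root.
from setuptools.config.setupcfg import read_configuration

cfg = read_configuration("setup.cfg")
print(cfg["metadata"]["version"])  # the value of zshot.__version__, e.g. '0.0.3'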

setup.py (-2)

@@ -4,10 +4,8 @@
 this_directory = Path(__file__).parent
 long_description = (this_directory / "README.md").read_text()

-version = '0.0.2'

 setup(name='zshot',
-      version=version,
       description="Zero and Few shot named entity recognition",
       long_description_content_type='text/markdown',
       long_description=long_description,

zshot/__init__.py (+2)

@@ -1,2 +1,4 @@
 from zshot.zshot import MentionsExtractor, Linker, Zshot, PipelineConfig  # noqa: F401
 from zshot.utils.displacy import displacy  # noqa: F401
+
+__version__ = '0.0.3'
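
With __version__ defined at the package root, the same value feeds both setup.cfg and runtime checks. A minimal sketch, assuming the package is installed (e.g. pip install -e .):

# Runtime check: the canonical version now lives in zshot/__init__.py.
import zshot

print(zshot.__version__)  # '0.0.3' after this commit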

zshot/evaluation/evaluator.py (+2 -1)

@@ -34,7 +34,8 @@ def prepare_pipeline(
         feature_extractor=None,  # noqa: F821
         device: int = None,
     ):
-        pipe = super(TokenClassificationEvaluator, self).prepare_pipeline(model_or_pipeline, tokenizer, feature_extractor, device)
+        pipe = super(TokenClassificationEvaluator, self).prepare_pipeline(model_or_pipeline, tokenizer,
+                                                                          feature_extractor, device)
         return pipe

zshot/evaluation/zshot_evaluate.py (+4 -2)

@@ -5,13 +5,12 @@
 from prettytable import PrettyTable

 from zshot.evaluation import load_medmentions, load_ontonotes
-from zshot.evaluation.dataset.dataset import DatasetWithEntities
 from zshot.evaluation.evaluator import ZeroShotTokenClassificationEvaluator, MentionsExtractorEvaluator
 from zshot.evaluation.pipeline import LinkerPipeline, MentionsExtractorPipeline


 def evaluate(nlp: spacy.language.Language,
-             datasets: Union[DatasetWithEntities, List[DatasetWithEntities]],
+             datasets: Union[str, List[str]],
              splits: Optional[Union[str, List[str]]] = None,
              metric: Optional[Union[str, EvaluationModule]] = None,
              batch_size: Optional[int] = 16) -> str:
@@ -31,6 +30,9 @@ def evaluate(nlp: spacy.language.Language,
     if type(splits) == str:
         splits = [splits]

+    if type(datasets) == str:
+        datasets = [datasets]
+
     result = {}
     field_names = ["Metric"]
     for dataset_name in datasets:
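
After this change, evaluate() accepts dataset names rather than pre-loaded DatasetWithEntities objects, and normalizes a bare string to a one-element list just as it already did for splits. A usage sketch, assuming "ontonotes" is one of the supported dataset keys and that nlp is a spaCy pipeline with zshot configured:

# Usage sketch for the updated signature (dataset names as str or list of str).
import spacy
from zshot.evaluation.zshot_evaluate import evaluate

nlp = spacy.blank("en")  # in real use, add the zshot pipe with its config first
# Equivalent calls after this commit:
print(evaluate(nlp, "ontonotes", splits="test", metric="seqeval"))
print(evaluate(nlp, ["ontonotes"], splits=["test"], metric="seqeval"))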

zshot/tests/evaluation/test_evaluation.py (+9 -8)

@@ -113,7 +113,7 @@ def test_prediction_token_based_evaluation_all_matching(self):
         dataset = get_dataset(gt, sentences)

         custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification")
-        metrics = custom_evaluator.compute(get_linker_pipe([('New York', 'FAC', 1)]), dataset, "seqeval")
+        metrics = custom_evaluator.compute(get_linker_pipe([('New York', 'FAC', 1)]), dataset, metric="seqeval")

         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
@@ -128,7 +128,7 @@ def test_prediction_token_based_evaluation_overlapping_spans(self):

         custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification")
         metrics = custom_evaluator.compute(get_linker_pipe([('New York', 'FAC', 1), ('York', 'LOC', 0.7)]), dataset,
-                                           "seqeval")
+                                           metric="seqeval")

         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
@@ -144,7 +144,7 @@ def test_prediction_token_based_evaluation_partial_match_spans_expand(self):
         custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification",
                                                                 alignment_mode=AlignmentMode.expand)
         pipe = get_linker_pipe([('New Yo', 'FAC', 1)])
-        metrics = custom_evaluator.compute(pipe, dataset, "seqeval")
+        metrics = custom_evaluator.compute(pipe, dataset, metric="seqeval")

         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
@@ -160,7 +160,7 @@ def test_prediction_token_based_evaluation_partial_match_spans_contract(self):
         custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification",
                                                                 alignment_mode=AlignmentMode.contract)
         pipe = get_linker_pipe([('New York i', 'FAC', 1)])
-        metrics = custom_evaluator.compute(pipe, dataset, "seqeval")
+        metrics = custom_evaluator.compute(pipe, dataset, metric="seqeval")

         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
@@ -176,7 +176,7 @@ def test_prediction_token_based_evaluation_partial_and_overlapping_spans(self):
         custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification",
                                                                 alignment_mode=AlignmentMode.contract)
         pipe = get_linker_pipe([('New York i', 'FAC', 1), ('w York', 'LOC', 0.7)])
-        metrics = custom_evaluator.compute(pipe, dataset, "seqeval")
+        metrics = custom_evaluator.compute(pipe, dataset, metric="seqeval")

         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
@@ -207,7 +207,8 @@ def test_prediction_token_based_evaluation_all_matching(self):
         dataset = get_dataset(gt, sentences)

         custom_evaluator = MentionsExtractorEvaluator("token-classification")
-        metrics = custom_evaluator.compute(get_mentions_extractor_pipe([('New York', 'FAC', 1)]), dataset, "seqeval")
+        metrics = custom_evaluator.compute(get_mentions_extractor_pipe([('New York', 'FAC', 1)]), dataset,
+                                           metric="seqeval")

         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
@@ -222,7 +223,7 @@ def test_prediction_token_based_evaluation_overlapping_spans(self):

         custom_evaluator = MentionsExtractorEvaluator("token-classification")
         metrics = custom_evaluator.compute(get_mentions_extractor_pipe([('New York', 'FAC', 1), ('York', 'LOC', 0.7)]),
-                                           dataset, "seqeval")
+                                           dataset, metric="seqeval")

         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
@@ -238,7 +239,7 @@ def test_prediction_token_based_evaluation_partial_match_spans_expand(self):
         custom_evaluator = MentionsExtractorEvaluator("token-classification",
                                                       alignment_mode=AlignmentMode.expand)
         pipe = get_mentions_extractor_pipe([('New Yo', 'FAC', 1)])
-        metrics = custom_evaluator.compute(pipe, dataset, "seqeval")
+        metrics = custom_evaluator.compute(pipe, dataset, metric="seqeval")

         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
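
The recurring change in these tests is that the metric is now passed by keyword. Newer releases of the evaluate library appear to accept additional keyword parameters ahead of metric in Evaluator.compute() (subset and split in current releases), so a positional "seqeval" would no longer land in the metric slot. A minimal sketch of the new call pattern; pipe and dataset stand in for fixtures built by this test file's helpers (get_linker_pipe, get_dataset), which are not shown here:

# Sketch of the call pattern the tests now use. In newer releases of the
# evaluate library, Evaluator.compute() accepts extra keyword parameters
# ahead of `metric`, so "seqeval" must be passed by keyword, not positionally.
from zshot.evaluation.evaluator import ZeroShotTokenClassificationEvaluator

evaluator = ZeroShotTokenClassificationEvaluator("token-classification")
# `pipe` and `dataset` come from this test file's helpers:
metrics = evaluator.compute(pipe, dataset, metric="seqeval")  # keyword, not positional
assert 0.0 <= float(metrics["overall_f1"]) <= 1.0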
