Flair: Add support for PoS Tagging Models & Version Update (#440)
* flair: add support for PoS tagging models

* flair: bump requirement to latest 0.14.0 release

* flair: add support for PoS tagging models

* flair: add support for PoS tagging models

* flair: add support for multiple model tests (incl. fixing unsupported task pipeline)

* flair: add parameterized test cases (incl. setup class cache clearing)

* flair: apply black to test class
stefan-it authored Aug 9, 2024
1 parent db248d4 commit 7f7255a
Showing 4 changed files with 43 additions and 21 deletions.
docker_images/flair/app/pipelines/token_classification.py: 24 additions & 15 deletions
@@ -1,7 +1,7 @@
 from typing import Any, Dict, List
 
 from app.pipelines import Pipeline
-from flair.data import Sentence
+from flair.data import Sentence, Span, Token
 from flair.models import SequenceTagger
 
 
@@ -27,21 +27,30 @@ def __call__(self, inputs: str) -> List[Dict[str, Any]]:
         """
         sentence: Sentence = Sentence(inputs)
 
-        # Also show scores for recognized NEs
-        self.tagger.predict(sentence, label_name="predicted")
+        self.tagger.predict(sentence)
 
         entities = []
-        for span in sentence.get_spans("predicted"):
-            if len(span.tokens) == 0:
-                continue
-            current_entity = {
-                "entity_group": span.tag,
-                "word": span.text,
-                "start": span.tokens[0].start_position,
-                "end": span.tokens[-1].end_position,
-                "score": span.score,
-            }
-
-            entities.append(current_entity)
+        for label in sentence.get_labels():
+            current_data_point = label.data_point
+            if isinstance(current_data_point, Token):
+                current_entity = {
+                    "entity_group": current_data_point.tag,
+                    "word": current_data_point.text,
+                    "start": current_data_point.start_position,
+                    "end": current_data_point.end_position,
+                    "score": current_data_point.score,
+                }
+                entities.append(current_entity)
+            elif isinstance(current_data_point, Span):
+                if not current_data_point.tokens:
+                    continue
+                current_entity = {
+                    "entity_group": current_data_point.tag,
+                    "word": current_data_point.text,
+                    "start": current_data_point.tokens[0].start_position,
+                    "end": current_data_point.tokens[-1].end_position,
+                    "score": current_data_point.score,
+                }
+                entities.append(current_entity)
 
         return entities
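
The key change is iterating over sentence.get_labels() instead of sentence.get_spans("predicted"): PoS taggers attach one label per Token, while NER and chunking models attach labels to multi-token Span objects, so branching on the label's data_point type covers both model families. A minimal standalone sketch of the same pattern (the example sentence is illustrative; flair/upos-english-fast is one of the models used in the tests below):

from flair.data import Sentence, Span, Token
from flair.models import SequenceTagger

# A PoS model from the test list; an NER model such as flair/ner-english-fast
# would go through the Span branch instead.
tagger = SequenceTagger.load("flair/upos-english-fast")

sentence = Sentence("George Washington went to Washington.")
tagger.predict(sentence)

entities = []
for label in sentence.get_labels():
    data_point = label.data_point
    if isinstance(data_point, Token):
        # Token-level labels (PoS tagging): one entry per token.
        entities.append({
            "entity_group": data_point.tag,
            "word": data_point.text,
            "start": data_point.start_position,
            "end": data_point.end_position,
            "score": data_point.score,
        })
    elif isinstance(data_point, Span) and data_point.tokens:
        # Span-level labels (NER, chunking): one entry per span.
        entities.append({
            "entity_group": data_point.tag,
            "word": data_point.text,
            "start": data_point.tokens[0].start_position,
            "end": data_point.tokens[-1].end_position,
            "score": data_point.score,
        })

print(entities)
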
docker_images/flair/requirements.txt: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 starlette==0.27.0
 pydantic==1.8.2
-flair @ git+https://github.com/flairNLP/flair@b18aff236098fc6623de8bdb4c8b50e4bfe7f91f
+flair @ git+https://github.com/flairNLP/flair@e17ab1234fcfed2b089d8ef02b99949d520382d2
 api-inference-community==0.0.25
docker_images/flair/tests/test_api.py: 7 additions & 3 deletions
@@ -1,5 +1,5 @@
 import os
-from typing import Dict
+from typing import Dict, List
 from unittest import TestCase, skipIf
 
 from app.main import ALLOWED_TASKS, get_pipeline
@@ -8,7 +8,9 @@
 # Must contain at least one example of each implemented pipeline
 # Tests do not check the actual values of the model output, so small dummy
 # models are recommended for faster tests.
-TESTABLE_MODELS: Dict[str, str] = {"token-classification": "flair/chunk-english-fast"}
+TESTABLE_MODELS: Dict[str, List[str]] = {
+    "token-classification": ["flair/chunk-english-fast", "flair/upos-english-fast"]
+}
 
 
 ALL_TASKS = {
@@ -35,5 +37,7 @@ def test_unsupported_tasks(self):
         unsupported_tasks = ALL_TASKS - ALLOWED_TASKS.keys()
         for unsupported_task in unsupported_tasks:
             with self.subTest(msg=unsupported_task, task=unsupported_task):
+                os.environ["TASK"] = unsupported_task
+                os.environ["MODEL_ID"] = "XX"
                 with self.assertRaises(EnvironmentError):
-                    get_pipeline(unsupported_task, model_id="XX")
+                    get_pipeline()
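
In the updated test, get_pipeline() is driven entirely by the TASK and MODEL_ID environment variables, and the cache_clear() call added in the next file implies the factory is memoized (e.g. with functools.lru_cache). The real implementation lives in app.main and is not shown in this diff; the following is only a rough sketch of that contract, with the pipeline class name assumed:

import functools
import os

from app.pipelines import TokenClassificationPipeline  # assumed class name

ALLOWED_TASKS = {"token-classification": TokenClassificationPipeline}


@functools.lru_cache()
def get_pipeline():
    # Configuration comes from the environment, not from arguments, so a single
    # cached pipeline instance is tied to the process-level TASK and MODEL_ID.
    task = os.environ["TASK"]
    model_id = os.environ["MODEL_ID"]
    if task not in ALLOWED_TASKS:
        raise EnvironmentError(f"{task} is not a valid pipeline for model {model_id}")
    return ALLOWED_TASKS[task](model_id)
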
docker_images/flair/tests/test_api_token_classification.py: 11 additions & 2 deletions
@@ -3,6 +3,7 @@
 from unittest import TestCase, skipIf
 
 from app.main import ALLOWED_TASKS
+from parameterized import parameterized_class
 from starlette.testclient import TestClient
 from tests.test_api import TESTABLE_MODELS
 
@@ -11,17 +12,25 @@
     "token-classification" not in ALLOWED_TASKS,
     "token-classification not implemented",
 )
+@parameterized_class(
+    [{"model_id": model_id} for model_id in TESTABLE_MODELS["token-classification"]]
+)
 class TokenClassificationTestCase(TestCase):
     def setUp(self):
-        model_id = TESTABLE_MODELS["token-classification"]
         self.old_model_id = os.getenv("MODEL_ID")
         self.old_task = os.getenv("TASK")
-        os.environ["MODEL_ID"] = model_id
+        os.environ["MODEL_ID"] = self.model_id
         os.environ["TASK"] = "token-classification"
         from app.main import app
 
         self.app = app
 
+    @classmethod
+    def setUpClass(cls):
+        from app.main import get_pipeline
+
+        get_pipeline.cache_clear()
+
     def tearDown(self):
         if self.old_model_id is not None:
             os.environ["MODEL_ID"] = self.old_model_id