Commit ab1fb43

adjust test_dataset_with_taskmodule to use local TaskModule implementation (#182)
* adjust test_dataset_with_taskmodule to use local TaskModule implementation (and simplify)
* move TestTaskModule to tests.common.taskmodule
* fix import
1 parent a8abacf commit ab1fb43

File tree

3 files changed: +287 -58 lines changed

Diff for: tests/common/__init__.py

Whitespace-only changes.

Diff for: tests/common/taskmodule.py

+228
@@ -0,0 +1,228 @@
"""
workflow:
    document
        -> (input_encoding, target_encoding) -> task_encoding
            -> model_encoding -> model_output
        -> task_output
    -> document
"""

import dataclasses
import logging
from typing import Dict, Iterator, List, Optional, Sequence, Tuple, TypedDict

import numpy as np
from pie_modules.annotations import Label
from pie_modules.documents import TextBasedDocument
from pytorch_ie import AnnotationLayer, TaskEncoding, TaskModule, annotation_field
from typing_extensions import TypeAlias

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class TestDocumentWithLabel(TextBasedDocument):
    label: AnnotationLayer[Label] = annotation_field()
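
# Illustrative usage (not part of the commit): such a document wraps a text and
# carries its label in an annotation layer, e.g.
#   doc = TestDocumentWithLabel(text="some text")
#   doc.label.append(Label(label="positive"))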


class TaskOutput(TypedDict, total=False):
    label: str
    probability: float


# Define task specific input and output types
DocumentType: TypeAlias = TestDocumentWithLabel
InputEncodingType: TypeAlias = List[int]
TargetEncodingType: TypeAlias = int
ModelInputType = List[List[int]]
ModelTargetType = List[int]
ModelEncodingType: TypeAlias = Tuple[
    ModelInputType,
    Optional[ModelTargetType],
]
ModelOutputType = Dict[str, List[List[float]]]
TaskOutputType: TypeAlias = TaskOutput

# This should be the same for all taskmodules
TaskEncodingType: TypeAlias = TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType]
TaskModuleType: TypeAlias = TaskModule[
    DocumentType,
    InputEncodingType,
    TargetEncodingType,
    ModelEncodingType,
    ModelOutputType,
    TaskOutputType,
]

def softmax(scores: List[float]) -> List[float]:
    """Compute the softmax of a list of scores."""
    max_score = max(scores)
    exp_scores = [np.exp(score - max_score) for score in scores]
    sum_exp_scores = sum(exp_scores)
    return [score / sum_exp_scores for score in exp_scores]
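
# Sanity check (illustrative, not part of the commit):
#   softmax([1.0, 2.0, 3.0]) ≈ [0.0900, 0.2447, 0.6652]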


def argmax(scores: List[float]) -> int:
    """Get the index of the maximum score."""
    max_index = 0
    max_value = scores[0]
    for i, score in enumerate(scores):
        if score > max_value:
            max_value = score
            max_index = i
    return max_index
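
# Sanity check (illustrative): argmax([0.1, 0.7, 0.2]) == 1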


@TaskModule.register()
class TestTaskModule(TaskModuleType):
    # If these attributes are set, the taskmodule is considered prepared. They should be
    # calculated within _prepare() and are dumped automatically when saving the taskmodule
    # with save_pretrained().
    PREPARED_ATTRIBUTES = ["labels"]
    DOCUMENT_TYPE = TestDocumentWithLabel

    def __init__(
        self,
        labels: Optional[List[str]] = None,
        **kwargs,
    ) -> None:
        # Important: Remaining keyword arguments need to be passed to super.
        super().__init__(**kwargs)
        # Save all passed arguments. They will be available via self._config().
        self.save_hyperparameters()

        self.labels = labels
        self.token2id = {"PAD": 0}
        self.id2token = {0: "PAD"}
        # set in _post_prepare(); initialized here so the check in encode_target() works
        self.label_to_id: Optional[Dict[str, int]] = None
        self.id_to_label: Optional[Dict[int, str]] = None

    def _prepare(self, documents: Sequence[DocumentType]) -> None:
        """Prepare the task module with training documents, e.g. collect all possible labels.

        This method needs to set all attributes listed in PREPARED_ATTRIBUTES.
        """

        # create the label-to-id mapping
        labels = set()
        for document in documents:
            # all annotations of a document are held in list-like containers,
            # so we take the first element
            label_annotation = document.label[0]
            labels.add(label_annotation.label)

        self.labels = sorted(labels)

    def _post_prepare(self):
        """Any further preparation logic that requires the result of _prepare(), but whose
        result is not serialized with the taskmodule.
        """
        # create the mapping, but spare the first index for the "O" (outside) class
        self.label_to_id = {label: i + 1 for i, label in enumerate(self.labels)}
        self.label_to_id["O"] = 0
        self.id_to_label = {v: k for k, v in self.label_to_id.items()}

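    # Illustrative mapping (not in the commit): with labels == ["negative", "positive"],
    # label_to_id becomes {"negative": 1, "positive": 2, "O": 0}.
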
    def tokenize(self, text: str) -> List[int]:
        """Tokenize the input text by splitting on whitespace, growing the vocabulary as needed."""
        # Tokenize the input text via whitespace
        tokens = text.split(" ")
        ids = []
        for token in tokens:
            # If the token is not already in the vocabulary, add it
            if token not in self.token2id:
                self.token2id[token] = len(self.token2id)
            ids.append(self.token2id[token])
        return ids

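    # Illustrative: with a fresh vocabulary ({"PAD": 0}), tokenize("a b a") -> [1, 2, 1].
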
    def token_ids2tokens(self, ids: List[int]) -> List[str]:
        """Convert token ids back to tokens."""
        # rebuild the reverse mapping if the vocabulary has grown since it was last built
        if len(self.id2token) != len(self.token2id):
            self.id2token = {v: k for k, v in self.token2id.items()}

        return [self.id2token[id] for id in ids]

    def encode_input(
        self,
        document: DocumentType,
    ) -> TaskEncodingType:
        """Create one or multiple task encodings for the given document."""

        # tokenize the input text, this will be the input
        inputs = self.tokenize(document.text)

        return TaskEncoding(
            document=document,
            inputs=inputs,
        )

    def encode_target(
        self,
        task_encoding: TaskEncodingType,
    ) -> TargetEncodingType:
        """Create a target for a task encoding.

        This may use any annotations of the underlying document.
        """

        # as above, all annotations are held in lists, so we take the first element
        label_annotation = task_encoding.document.label[0]
        # translate the textual label to the target id
        if self.label_to_id is None:
            raise ValueError(
                "Task module is not prepared. Call prepare() or post_prepare() first."
            )
        return self.label_to_id[label_annotation.label]

    def collate(self, task_encodings: Sequence[TaskEncodingType]) -> ModelEncodingType:
        """Convert a list of task encodings to a batch that will be passed to the model."""
        # get the inputs from the task encodings
        inputs = [task_encoding.inputs for task_encoding in task_encodings]

        if task_encodings[0].has_targets:
            # get the targets (label ids) from the task encodings
            targets = [task_encoding.targets for task_encoding in task_encodings]
        else:
            # during inference, we do not have any targets
            targets = None

        return inputs, targets

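    # Illustrative batch (not in the commit): two training encodings might collate to
    #   ([[1, 2, 3], [1, 4]], [2, 1])  # (token id lists, label ids)
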
    def unbatch_output(self, model_output: ModelOutputType) -> Sequence[TaskOutputType]:
        """Convert one model output batch to a sequence of taskmodule outputs."""

        # get the logits from the model output
        logits = model_output["logits"]

        # convert the logits to "probabilities"
        probabilities = [softmax(scores) for scores in logits]

        # get the max class index per example
        max_label_ids = [argmax(probs) for probs in probabilities]

        outputs = []
        for idx, label_id in enumerate(max_label_ids):
            # translate the label id back to the label text
            label = self.id_to_label[label_id]
            # get the probability and convert to a python float
            prob = round(float(probabilities[idx][label_id]), 4)
            # we create a TaskOutput primarily for typing purposes, a simple dict would also work
            result: TaskOutput = {
                "label": label,
                "probability": prob,
            }
            outputs.append(result)

        return outputs

    def create_annotations_from_output(
        self,
        task_encodings: TaskEncodingType,
        task_outputs: TaskOutputType,
    ) -> Iterator[Tuple[str, Label]]:
        """Convert a task output to annotations.

        The method has to yield tuples (annotation_name, annotation).
        """

        # just yield a single annotation (other tasks may need multiple annotations per task output)
        yield "label", Label(label=task_outputs["label"], score=task_outputs["probability"])
