1 change: 1 addition & 0 deletions src/pytorch_ie/core/__init__.py
@@ -1,4 +1,5 @@
from .document import Annotation, AnnotationList, Document, annotation_field
from .metric import DocumentMetric
from .model import PyTorchIEModel
from .statistic import DocumentStatistic
from .taskmodule import TaskEncoding, TaskModule
105 changes: 105 additions & 0 deletions src/pytorch_ie/core/statistic.py
@@ -0,0 +1,105 @@
from abc import abstractmethod
from collections import defaultdict
from typing import Any, Dict, Generator, List, Tuple, Union

from pytorch_ie.core.document import Document
from pytorch_ie.core.metric import DocumentMetric


def _flatten_dict_gen(d: Dict[str, Any], parent_key: Tuple[str, ...] = ()) -> Generator:
    for k, v in d.items():
        new_key = parent_key + (k,)
        if isinstance(v, dict):
            yield from _flatten_dict_gen(v, new_key)
        else:
            yield new_key, v


def flatten_dict(d: Dict[str, Any]) -> Dict[Tuple[str, ...], Any]:
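    """Flattens a nested dictionary into a dictionary with tuple keys.

    Example:
        >>> flatten_dict({"a": {"b": 1}, "c": 2})
        {('a', 'b'): 1, ('c',): 2}
    """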
return dict(_flatten_dict_gen(d))


def unflatten_dict(d: Dict[Tuple[str, ...], Any]) -> Union[Dict[str, Any], Any]:
"""Unflattens a dictionary with nested keys.

Example:
>>> d = {("a", "b", "c"): 1, ("a", "b", "d"): 2, ("a", "e"): 3}
>>> unflatten_dict(d)
{'a': {'b': {'c': 1, 'd': 2}, 'e': 3}}
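        >>> unflatten_dict({(): 5})
        5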
"""
result: Dict[str, Any] = {}
for k, v in d.items():
if len(k) == 0:
            if len(d) > 1:
raise ValueError("Cannot unflatten dictionary with multiple root keys.")
return v
current = result
for key in k[:-1]:
current = current.setdefault(key, {})
current[k[-1]] = v
return result


class DocumentStatistic(DocumentMetric):
"""A special type of metric that collects statistics from a document.

Usage:

```python
from transformers import AutoTokenizer
from pytorch_ie import DatasetDict
from pytorch_ie.core import Document, DocumentStatistic

class TokenCountCollector(DocumentStatistic):
def __init__(self, tokenizer_name_or_path: str, field: str, **kwargs):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
self.kwargs = kwargs
self.field = field

def _collect(self, doc: Document) -> int:
text = getattr(doc, self.field)
encodings = self.tokenizer(text, **self.kwargs)
tokens = encodings.tokens()
return len(tokens)

dataset = DatasetDict.load_dataset("pie/conll2003")
statistic = TokenCountCollector(tokenizer_name_or_path="bert-base-cased", field="text")
values = statistic(dataset)
assert values == {
'train': [12, 4, 11, 34, 39, 43, 27, 50, 45, 43, ...],
'validation': [37, 11, 40, 43, 44, 27, 35, 40, 48, 43, ...],
'test': [33, 8, 15, 29, 30, 56, 31, 19, 21, 30, ...],
}
```
"""

def reset(self) -> None:
self._values: List[Any] = []

@abstractmethod
def _collect(self, doc: Document) -> Any:
"""Collect any values from a document."""

def _update(self, document: Document) -> None:
values = self._collect(document)
self._values.append(values)

def _compute(self) -> Any:
"""We just integrate the values by creating lists for each leaf of the (nested)
dictionary."""
stats = defaultdict(list)
for collected_result in self._values:
if isinstance(collected_result, dict):
collected_result_flat = flatten_dict(collected_result)
for k, v in collected_result_flat.items():
if isinstance(v, list):
stats[k].extend(v)
else:
stats[k].append(v)
else:
if isinstance(collected_result, list):
stats[()].extend(collected_result)
else:
stats[()].append(collected_result)
return unflatten_dict(dict(stats))
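
To illustrate how `_compute` aggregates per-document results, here is a minimal sketch of a custom statistic whose collector returns a nested dictionary; the class name and the assumption of a text-based document with a `text` field are illustrative, not part of this PR:

```python
from typing import Dict

from pytorch_ie.core import Document, DocumentStatistic


class TextSizeCollector(DocumentStatistic):
    """Illustrative statistic: collects simple size measures of the document text."""

    def _collect(self, doc: Document) -> Dict[str, Dict[str, int]]:
        # assumes a text-based document (e.g. pytorch_ie.documents.TextBasedDocument)
        return {"text": {"chars": len(doc.text), "words": len(doc.text.split())}}


# Applied to a DatasetDict, the per-document values are merged per leaf key, e.g.:
# {"train": {"text": {"chars": [...], "words": [...]}}, ...}
```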
106 changes: 106 additions & 0 deletions src/pytorch_ie/metrics/statistics.py
@@ -0,0 +1,106 @@
from collections import defaultdict
from typing import Dict, List

from transformers import AutoTokenizer

from pytorch_ie.core import Document, DocumentStatistic


class TokenCountCollector(DocumentStatistic):
"""Collects the token count of a field when tokenizing its content with a Huggingface tokenizer.

The field should be a string.
"""

def __init__(self, tokenizer_name_or_path: str, field: str, **kwargs):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
self.kwargs = kwargs
self.field = field

def _collect(self, doc: Document) -> int:
text = getattr(doc, self.field)
encodings = self.tokenizer(text, **self.kwargs)
tokens = encodings.tokens()
return len(tokens)


class FieldLengthCollector(DocumentStatistic):
"""Collects the length of a field, e.g. to collect the number the characters in the input text.

    The field should be a sized object, e.g. a string or a list.
"""

def __init__(self, field: str):
super().__init__()
self.field = field

def _collect(self, doc: Document) -> int:
field_obj = getattr(doc, self.field)
return len(field_obj)


class SubFieldLengthCollector(DocumentStatistic):
"""Collects the length of a subfield in a field, e.g. to collect the number of arguments of N-ary relations."""

def __init__(self, field: str, subfield: str):
super().__init__()
self.field = field
self.subfield = subfield

def _collect(self, doc: Document) -> List[int]:
field_obj = getattr(doc, self.field)
lengths = []
for entry in field_obj:
subfield_obj = getattr(entry, self.subfield)
lengths.append(len(subfield_obj))
return lengths


class LabeledSpanLengthCollector(DocumentStatistic):
"""Collects the length of spans in a field per label, e.g. to collect the length of entities per type.

    The field should be a list of elements, each with a label, a start, and an end attribute.
"""

def __init__(self, field: str):
super().__init__()
self.field = field

def _collect(self, doc: Document) -> Dict[str, List[int]]:
field_obj = getattr(doc, self.field)
counts = defaultdict(list)
for elem in field_obj:
counts[elem.label].append(elem.end - elem.start)
return dict(counts)


class DummyCollector(DocumentStatistic):
"""A dummy collector that always returns 1, e.g. to count the number of documents.

Can be used to count the number of documents.
"""

def _collect(self, doc: Document) -> int:
return 1


class LabelCountCollector(DocumentStatistic):
"""Collects the number of field entries per label, e.g. to collect the number of entities per type.

The field should be a list of elements with a label attribute.

    Important: To aggregate the resulting data correctly, missing values need to be filled with 0, e.g.:
{("ORG",): [2, 3], ("LOC",): [2]} -> {("ORG",): [2, 3], ("LOC",): [2, 0]}
"""

def __init__(self, field: str):
super().__init__()
self.field = field

def _collect(self, doc: Document) -> Dict[str, int]:
field_obj = getattr(doc, self.field)
        counts: Dict[str, int] = defaultdict(int)
for elem in field_obj:
counts[elem.label] += 1
return dict(counts)
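
The note in `LabelCountCollector` about filling missing values with 0 refers to a post-processing step on the flattened result; a minimal sketch of such a step (the helper `fill_missing_counts` is illustrative, not part of this PR):

```python
from typing import Dict, List, Tuple


def fill_missing_counts(
    flat_counts: Dict[Tuple[str, ...], List[int]], num_documents: int
) -> Dict[Tuple[str, ...], List[int]]:
    """Pad each list with zeros so that every key has one count per document."""
    return {
        key: values + [0] * (num_documents - len(values))
        for key, values in flat_counts.items()
    }


# reproduces the example from the docstring (two documents, "LOC" missing in one):
# fill_missing_counts({("ORG",): [2, 3], ("LOC",): [2]}, num_documents=2)
# -> {("ORG",): [2, 3], ("LOC",): [2, 0]}
```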
104 changes: 104 additions & 0 deletions tests/core/test_statistic.py
@@ -0,0 +1,104 @@
import dataclasses

import pytest

from pytorch_ie import DatasetDict
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.core.statistic import flatten_dict
from pytorch_ie.documents import TextBasedDocument
from pytorch_ie.metrics.statistics import (
DummyCollector,
FieldLengthCollector,
LabelCountCollector,
LabeledSpanLengthCollector,
SubFieldLengthCollector,
TokenCountCollector,
)
from tests import FIXTURES_ROOT


@pytest.fixture
def dataset():
@dataclasses.dataclass
class Conll2003Document(TextBasedDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")

return DatasetDict.from_json(
data_dir=FIXTURES_ROOT / "dataset_dict" / "conll2003_extract",
document_type=Conll2003Document,
)


def test_prepare_data(dataset):
statistic = DummyCollector()
values_nested = statistic(dataset)
prepared_data = flatten_dict(values_nested)
assert prepared_data == {
("train",): [1, 1, 1],
("test",): [1, 1, 1],
("validation",): [1, 1, 1],
}
statistic = LabelCountCollector(field="entities")
values_nested = statistic(dataset)
prepared_data = flatten_dict(values_nested)
assert prepared_data == {
("train", "ORG"): [2],
("train", "MISC"): [3],
("train", "PER"): [2],
("train", "LOC"): [2],
("test", "LOC"): [2, 3],
("test", "PER"): [2, 2],
("validation", "ORG"): [2, 3],
("validation", "LOC"): [2],
("validation", "MISC"): [2],
("validation", "PER"): [2],
}
statistic = FieldLengthCollector(field="text")
values_nested = statistic(dataset)
prepared_data = flatten_dict(values_nested)
assert prepared_data == {
("train",): [48, 15, 19],
("test",): [57, 11, 40],
("validation",): [65, 17, 187],
}

statistic = LabeledSpanLengthCollector(field="entities")
values_nested = statistic(dataset)
prepared_data = flatten_dict(values_nested)
assert prepared_data == {
("train", "ORG"): [2],
("train", "MISC"): [6, 7],
("train", "PER"): [15],
("train", "LOC"): [8],
("test", "LOC"): [5, 6, 20],
("test", "PER"): [5, 11],
("validation", "ORG"): [14, 14, 8],
("validation", "LOC"): [6],
("validation", "MISC"): [11],
("validation", "PER"): [12],
}

    # this is not super useful (we just collect the lengths of the label strings), but it is enough to test the code
statistic = SubFieldLengthCollector(field="entities", subfield="label")
values_nested = statistic(dataset)
prepared_data = flatten_dict(values_nested)
assert prepared_data == {
("train",): [3, 4, 4, 3, 3],
("test",): [3, 3, 3, 3, 3],
("validation",): [3, 3, 4, 3, 3, 3],
}


@pytest.mark.slow
def test_prepare_data_tokenize(dataset):
statistic = TokenCountCollector(
field="text", tokenizer_name_or_path="bert-base-uncased", add_special_tokens=False
)
values_nested = statistic(dataset)
prepared_data = flatten_dict(values_nested)
assert prepared_data == {
("train",): [9, 2, 6],
("test",): [12, 4, 12],
("validation",): [11, 6, 38],
}
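
Since the flattened results are plain lists of numbers per key, summary statistics can be computed directly from them; a small sketch using the standard library (the values and chosen aggregations are illustrative, taken from the token-count test above):

```python
from statistics import mean

# values in the shape produced by flatten_dict(statistic(dataset))
prepared_data = {
    ("train",): [9, 2, 6],
    ("test",): [12, 4, 12],
    ("validation",): [11, 6, 38],
}

summary = {
    key: {"mean": round(mean(values), 2), "min": min(values), "max": max(values)}
    for key, values in prepared_data.items()
}
# summary[("train",)] == {"mean": 5.67, "min": 2, "max": 9}
```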