diff --git a/setup.py b/setup.py
index 232ff91..4f1c67d 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 from setuptools import find_packages, setup
 
 REQUIRED_PKGS = [
-    "pytorch-ie>=0.16.0,<1.0.0",
+    "pytorch-ie>=0.17.0,<1.0.0",
     "asciidag",  # for visualization
 ]
 
diff --git a/src/pie_utils/__init__.py b/src/pie_utils/__init__.py
index ccfe772..e69de29 100644
--- a/src/pie_utils/__init__.py
+++ b/src/pie_utils/__init__.py
@@ -1 +0,0 @@
-from .dataset_dict import DatasetDict
diff --git a/src/pie_utils/dataset_dict.py b/src/pie_utils/dataset_dict.py
deleted file mode 100644
index cc685ef..0000000
--- a/src/pie_utils/dataset_dict.py
+++ /dev/null
@@ -1,318 +0,0 @@
-import json
-import logging
-import os
-from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, SupportsIndex, Type, Union
-
-import datasets
-from pytorch_ie import Dataset, IterableDataset
-from pytorch_ie.core import Document
-
-from pie_utils.document.processors.common import (
-    EnterDatasetDictMixin,
-    EnterDatasetMixin,
-    ExitDatasetDictMixin,
-    ExitDatasetMixin,
-)
-from pie_utils.hydra import resolve_target
-
-logger = logging.getLogger(__name__)
-
-
-def get_pie_dataset_type(
-    hf_dataset: Union[datasets.Dataset, datasets.IterableDataset]
-) -> Union[Type[Dataset], Type[IterableDataset]]:
-    if isinstance(hf_dataset, datasets.Dataset):
-        return Dataset
-    elif isinstance(hf_dataset, datasets.IterableDataset):
-        return IterableDataset
-    else:
-        raise ValueError(
-            f"dataset_split must be of type Dataset or IterableDataset, but is {type(hf_dataset)}"
-        )
-
-
-class DatasetDict(datasets.DatasetDict):
-    def __getitem__(self, k) -> Dataset:
-        return super().__getitem__(k)
-
-    # @classmethod
-    # def load_dataset(cls, *args, **kwargs) -> "DatasetDict":
-    #     return cls(datasets.load_dataset(*args, **kwargs))
-
-    @classmethod
-    def from_hf_dataset(
-        cls, hf_dataset: datasets.DatasetDict, document_type: Union[str, Type[Document]]
-    ) -> "DatasetDict":
-        doc_type = resolve_target(document_type)
-        res = cls(
-            {
-                k: get_pie_dataset_type(v).from_hf_dataset(v, document_type=doc_type)
-                for k, v in hf_dataset.items()
-            }
-        )
-        return res
-
-    @classmethod
-    def from_json(
-        cls,
-        document_type: Union[Type[Document], str],
-        **kwargs,
-    ) -> "DatasetDict":
-        return cls.from_hf_dataset(
-            datasets.load_dataset("json", **kwargs), document_type=document_type
-        )
-
-    def to_json(self, path: Union[str, Path], **kwargs) -> None:
-        path = Path(path)
-        for split, dataset in self.items():
-            split_path = path / split
-            logger.info(f'serialize documents to "{split_path}" ...')
-            os.makedirs(split_path, exist_ok=True)
-            file_name = split_path / "documents.jsonl"
-            with open(file_name, "w") as f:
-                for doc in dataset:
-                    f.write(json.dumps(doc.asdict(), **kwargs) + "\n")
-
-    @property
-    def document_type(self) -> Type[Document]:
-        """Returns the document type of the dataset.
-
-        If there are no splits in the dataset, returns None. Raises an error if the dataset splits
-        have different document types.
-        """
-
-        if len(self) == 0:
-            raise ValueError("dataset does not contain any splits, cannot determine document type")
-        document_types = {ds.document_type for ds in self.values()}
-        if len(document_types) > 1:
-            raise ValueError(
-                f"dataset contains splits with different document types: {document_types}"
-            )
-        return next(iter(document_types))
-
-    @property
-    def dataset_type(self) -> Union[Type[Dataset], Type[IterableDataset]]:
-        """Returns the dataset type of the dataset.
-
-        If there are no splits in the dataset, returns None. Raises an error if the dataset splits
-        have different dataset types.
-        """
-
-        if len(self) == 0:
-            raise ValueError(
-                "dataset does not contain any splits, cannot determine the dataset type"
-            )
-        dataset_types = {type(ds) for ds in self.values()}
-        if len(dataset_types) > 1:
-            raise ValueError(
-                f"dataset contains splits with different dataset types: {dataset_types}"
-            )
-        return next(iter(dataset_types))
-
-    def map(
-        self,
-        function: Optional[Union[Callable, str]] = None,
-        result_document_type: Optional[Union[str, Type[Document]]] = None,
-        **kwargs,
-    ) -> "DatasetDict":
-        if function is not None:
-            func = resolve_target(function)
-        else:
-
-            def identity(x):
-                # exclude from coverage because its usage happens in the map which is not collected
-                return x  # pragma: no cover
-
-            func = identity
-        map_kwargs = dict(function=func, **kwargs)
-        if result_document_type is not None:
-            map_kwargs["result_document_type"] = resolve_target(result_document_type)
-
-        if isinstance(func, EnterDatasetDictMixin):
-            func.enter_dataset_dict(self)
-
-        result_dict = {}
-        for split, dataset in self.items():
-            if isinstance(func, EnterDatasetMixin):
-                func.enter_dataset(dataset=dataset, name=split)
-            result_dict[split] = dataset.map(**map_kwargs)
-            if isinstance(func, ExitDatasetMixin):
-                func.exit_dataset(dataset=result_dict[split], name=split)
-
-        result = type(self)(result_dict)
-
-        if isinstance(func, ExitDatasetDictMixin):
-            func.exit_dataset_dict(result)
-
-        return result
-
-    def select(
-        self,
-        split: str,
-        start: Optional[SupportsIndex] = None,
-        stop: Optional[SupportsIndex] = None,
-        step: Optional[SupportsIndex] = None,
-        **kwargs,
-    ) -> "DatasetDict":
-        if stop is not None:
-            range_args = [stop]
-            if start is not None:
-                range_args = [start] + range_args
-            if step is not None:
-                range_args = range_args + [step]
-            kwargs["indices"] = range(*range_args)
-
-        if "indices" in kwargs:
-            result = type(self)(self)
-            pie_split = result[split]
-            result[split] = Dataset.from_hf_dataset(
-                dataset=pie_split.select(**kwargs), document_type=pie_split.document_type
-            )
-            return result
-        else:
-            if len(kwargs) > 0:
-                logger.warning(
-                    f"arguments for dataset.select() available, but they do not contain 'indices' which is required, "
-                    f"so we do not call select. provided arguments: \n{json.dumps(kwargs, indent=2)}"
-                )
-            return self
-
-    def rename_splits(
-        self,
-        mapping: Optional[Dict[str, str]] = None,
-        keep_other_splits: bool = True,
-    ) -> "DatasetDict":
-        if mapping is None:
-            mapping = {}
-        result = type(self)(
-            {
-                mapping.get(name, name): data
-                for name, data in self.items()
-                if name in mapping or keep_other_splits
-            }
-        )
-        return result
-
-    def add_test_split(
-        self,
-        source_split: str = "train",
-        target_split: str = "test",
-        **kwargs,
-    ) -> "DatasetDict":
-        split_result_hf = self[source_split].train_test_split(**kwargs)
-        split_result = type(self)(
-            {
-                name: Dataset.from_hf_dataset(ds, document_type=self[source_split].document_type)
-                for name, ds in split_result_hf.items()
-            }
-        )
-        res = type(self)(self)
-        res[source_split] = split_result["train"]
-        res[target_split] = split_result["test"]
-        split_sizes = {k: len(v) for k, v in res.items()}
-        logger.info(f"dataset size after adding the split: {split_sizes}")
-        return res
-
-    def drop_splits(self, split_names: List[str]) -> "DatasetDict":
-        result = type(self)({name: ds for name, ds in self.items() if name not in split_names})
-        return result
-
-    def concat_splits(self, splits: List[str], target: str) -> "DatasetDict":
-        result = type(self)({name: ds for name, ds in self.items() if name not in splits})
-        splits_to_concat = [self[name] for name in splits]
-        if len(splits_to_concat) == 0:
-            raise ValueError("please provide at least one split to concatenate")
-
-        concatenated = datasets.concatenate_datasets(splits_to_concat)
-        result[target] = self.dataset_type.from_hf_dataset(
-            concatenated, document_type=self.document_type
-        )
-        split_sizes = {k: len(v) for k, v in result.items()}
-        logger.info(f"dataset size after concatenating splits: {split_sizes}")
-        return result
-
-    def filter(
-        self,
-        split: str,
-        function: Optional[Union[Callable, str]] = None,
-        result_split_name: Optional[str] = None,
-        **kwargs,
-    ) -> "DatasetDict":
-        if function is not None:
-            # create a shallow copy to not modify the input
-            result = type(self)(self)
-            function = resolve_target(function)
-            pie_split = result[split]
-            # TODO: Implement pytorch_ie.Dataset.filter() in a similar way such as map() to make use of the
-            # document type. For now, the filter function is called directly on the HF dataset and thus needs to
-            # accept a dict as input.
-            # we need to convert the dataset back to HF because the filter function internally uses map() which will
-            # break if the PIE variant is used
-            if isinstance(pie_split, Dataset):
-                hf_split = datasets.Dataset(**Dataset.get_base_kwargs(pie_split))
-            elif isinstance(pie_split, IterableDataset):
-                hf_split = datasets.IterableDataset(**IterableDataset.get_base_kwargs(pie_split))
-            else:
-                raise ValueError(f"dataset split has unknown type: {type(pie_split)}")
-            hf_split_filtered = hf_split.filter(function=function, **kwargs)
-            target_split_name = result_split_name or split
-            result[target_split_name] = type(pie_split).from_hf_dataset(
-                dataset=hf_split_filtered, document_type=pie_split.document_type
-            )
-            # iterable datasets do not have a length
-            if not isinstance(result[target_split_name], IterableDataset):
-                logger.info(
-                    f"filtered split [{target_split_name}] has {len(result[target_split_name])} entries"
-                )
-            return result
-        else:
-            return self
-
-    def move_to_new_split(
-        self,
-        ids: Optional[List[str]] = None,
-        filter_function: Optional[Union[Callable[[Dict[str, Any]], bool], str]] = None,
-        source_split: str = "train",
-        target_split: str = "test",
-    ) -> "DatasetDict":
-        if filter_function is not None:
-            filter_func = resolve_target(filter_function)
-        else:
-            if ids is None:
-                raise ValueError("please provide either a list of ids or a filter function")
-
-            ids_set = set(ids)
-
-            def filter_with_ids(ex: Dict[str, Any]):
-                # exclude from coverage because its usage happens in the map which is not collected
-                return ex["id"] in ids_set  # pragma: no cover
-
-            filter_func = filter_with_ids
-
-        dataset_with_only_ids = self.filter(
-            split=source_split,
-            function=filter_func,
-        )
-        dataset_without_ids = self.filter(
-            split=source_split,
-            function=lambda ex: not filter_func(ex),
-        )
-        dataset_without_ids[target_split] = dataset_with_only_ids[source_split]
-
-        split_sizes = {k: len(v) for k, v in dataset_without_ids.items()}
-        logger.info(f"dataset size after moving to new split: {split_sizes}")
-        return dataset_without_ids
-
-    def cast_document_type(
-        self, new_document_type: Union[Type[Document], str], **kwargs
-    ) -> "DatasetDict":
-        new_type = resolve_target(new_document_type)
-
-        result = type(self)(
-            {
-                name: ds.cast_document_type(new_document_type=new_type, **kwargs)
-                for name, ds in self.items()
-            }
-        )
-        return result
diff --git a/src/pie_utils/document/processors/candidate_relation_adder.py b/src/pie_utils/document/processors/candidate_relation_adder.py
index fa0f472..bbb862a 100644
--- a/src/pie_utils/document/processors/candidate_relation_adder.py
+++ b/src/pie_utils/document/processors/candidate_relation_adder.py
@@ -9,7 +9,6 @@
 from pytorch_ie import Dataset, IterableDataset
 from pytorch_ie.annotations import BinaryRelation
 from pytorch_ie.core import AnnotationList, Document
-from pytorch_ie.core.document import BaseAnnotationList
 from pytorch_ie.utils.span import is_contained_in
 
 from pie_utils.document.processors.common import EnterDatasetMixin, ExitDatasetMixin
@@ -20,14 +19,6 @@ D = TypeVar("D", bound=Document)
 
 
-def target_layers(layer: BaseAnnotationList) -> dict[str, AnnotationList]:
-    return {
-        target_layer_name: layer._document[target_layer_name]
-        for target_layer_name in layer._targets
-        if target_layer_name in layer._document
-    }
-
-
 class CandidateRelationAdder(EnterDatasetMixin, ExitDatasetMixin):
     """CandidateRelationAdder adds binary relations to a document based on various parameters.
 
     It goes through combinations of available entity pairs as possible candidates for new relations.
@@ -149,13 +140,8 @@ def __call__(self, document: D) -> D:
             available_partitions = document[self.partition_layer]
         else:
             available_partitions = [None]
-        rel_target_layers = target_layers(layer=rel_layer)
-        if not len(rel_target_layers) == 1:
-            raise ValueError(
-                f"Relation layer must have exactly one target layer but found the following target layers: "
-                f"{list(rel_target_layers)}"
-            )
-        entity_layer = list(rel_target_layers.values())[0]
+
+        entity_layer = rel_layer.target_layer
         if self.use_predictions:
             entity_layer = entity_layer.predictions
 
diff --git a/src/pie_utils/hydra.py b/src/pie_utils/hydra.py
deleted file mode 100644
index c648cd0..0000000
--- a/src/pie_utils/hydra.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# taken from hydra/_internal/instantiate/_instantiate2.py
-from typing import Any, Callable, Union
-
-
-class HydraException(Exception):
-    ...
-
-
-class CompactHydraException(HydraException):
-    ...
-
-
-class InstantiationException(CompactHydraException):
-    ...
-
-
-def _locate(path: str) -> Any:
-    """Locate an object by name or dotted path, importing as necessary.
-
-    This is similar to the pydoc function `locate`, except that it checks for
-    the module from the given path from back to front.
-    """
-    if path == "":
-        raise ImportError("Empty path")
-    from importlib import import_module
-    from types import ModuleType
-
-    parts = [part for part in path.split(".")]
-    for part in parts:
-        if not len(part):
-            raise ValueError(
-                f"Error loading '{path}': invalid dotstring."
-                + "\nRelative imports are not supported."
-            )
-    assert len(parts) > 0
-    part0 = parts[0]
-    try:
-        obj = import_module(part0)
-    except Exception as exc_import:
-        raise ImportError(
-            f"Error loading '{path}':\n{repr(exc_import)}"
-            + f"\nAre you sure that module '{part0}' is installed?"
-        ) from exc_import
-    for m in range(1, len(parts)):
-        part = parts[m]
-        try:
-            obj = getattr(obj, part)
-        except AttributeError as exc_attr:
-            parent_dotpath = ".".join(parts[:m])
-            if isinstance(obj, ModuleType):
-                mod = ".".join(parts[: m + 1])
-                try:
-                    obj = import_module(mod)
-                    continue
-                except ModuleNotFoundError as exc_import:
-                    raise ImportError(
-                        f"Error loading '{path}':\n{repr(exc_import)}"
-                        + f"\nAre you sure that '{part}' is importable from module '{parent_dotpath}'?"
-                    ) from exc_import
-                except Exception as exc_import:
-                    raise ImportError(
-                        f"Error loading '{path}':\n{repr(exc_import)}"
-                    ) from exc_import
-            raise ImportError(
-                f"Error loading '{path}':\n{repr(exc_attr)}"
-                + f"\nAre you sure that '{part}' is an attribute of '{parent_dotpath}'?"
-            ) from exc_attr
-    return obj
-
-
-def resolve_target(
-    target: Union[str, type, Callable[..., Any]], full_key: str = ""
-) -> Union[type, Callable[..., Any]]:
-    """Resolve target string, type or callable into type or callable."""
-    if isinstance(target, str):
-        try:
-            target = _locate(target)
-        except Exception as e:
-            msg = f"Error locating target '{target}', set env var HYDRA_FULL_ERROR=1 to see chained exception."
-            if full_key:
-                msg += f"\nfull_key: {full_key}"
-            raise InstantiationException(msg) from e
-    if not callable(target):
-        msg = f"Expected a callable target, got '{target}' of type '{type(target).__name__}'"
-        if full_key:
-            msg += f"\nfull_key: {full_key}"
-        raise InstantiationException(msg)
-    return target
diff --git a/tests/document/processors/test_candidate_relation_adder.py b/tests/document/processors/test_candidate_relation_adder.py
index 6b45ef4..dfc6368 100644
--- a/tests/document/processors/test_candidate_relation_adder.py
+++ b/tests/document/processors/test_candidate_relation_adder.py
@@ -8,9 +8,7 @@
 from pytorch_ie.core import AnnotationList, annotation_field
 from pytorch_ie.documents import TextDocument
 
-from pie_utils import DatasetDict
 from pie_utils.document.processors import CandidateRelationAdder
-from tests import FIXTURES_ROOT
 from tests.document.processors.common import DocumentWithEntitiesRelationsAndPartitions
 
 
@@ -163,15 +161,7 @@ def test_candidate_relation_adder_use_predictions():
     assert relation.label == "no_relation"
 
 
-@pytest.fixture(scope="module")
-def dataset_dict():
-    return DatasetDict.from_json(
-        data_dir=FIXTURES_ROOT / "dataset_dict" / "conll2003_extract",
-        document_type=DocumentWithEntitiesRelationsAndPartitions,
-    )
-
-
-def test_candidate_relation_adder_with_statistics(document1, dataset_dict, caplog):
+def test_candidate_relation_adder_with_statistics(document1, caplog):
     candidate_relation_adder_with_statistics = CandidateRelationAdder(
         label="no_relation",
         collect_statistics=True,
diff --git a/tests/fixtures/dataset_dict/conll2003_extract/test/documents.jsonl b/tests/fixtures/dataset_dict/conll2003_extract/test/documents.jsonl
deleted file mode 100644
index cab3b47..0000000
--- a/tests/fixtures/dataset_dict/conll2003_extract/test/documents.jsonl
+++ /dev/null
@@ -1,3 +0,0 @@
-{"text": "SOCCER - JAPAN GET LUCKY WIN , CHINA IN SURPRISE DEFEAT .", "id": "0", "metadata": null, "entities": {"annotations": [{"start": 9, "end": 14, "label": "LOC", "score": 1.0, "_id": 5748529920044125636}, {"start": 31, "end": 36, "label": "PER", "score": 1.0, "_id": -8589846503006405843}], "predictions": []}}
-{"text": "Nadim Ladki", "id": "1", "metadata": null, "entities": {"annotations": [{"start": 0, "end": 11, "label": "PER", "score": 1.0, "_id": 6213877346071781521}], "predictions": []}}
-{"text": "AL-AIN , United Arab Emirates 1996-12-06", "id": "2", "metadata": null, "entities": {"annotations": [{"start": 0, "end": 6, "label": "LOC", "score": 1.0, "_id": 5985981094986103334}, {"start": 9, "end": 29, "label": "LOC", "score": 1.0, "_id": 3857557082585004468}], "predictions": []}}
diff --git a/tests/fixtures/dataset_dict/conll2003_extract/train/documents.jsonl b/tests/fixtures/dataset_dict/conll2003_extract/train/documents.jsonl
deleted file mode 100644
index ad49b68..0000000
--- a/tests/fixtures/dataset_dict/conll2003_extract/train/documents.jsonl
+++ /dev/null
@@ -1,3 +0,0 @@
-{"text": "EU rejects German call to boycott British lamb .", "id": "0", "metadata": null, "entities": {"annotations": [{"start": 0, "end": 2, "label": "ORG", "score": 1.0, "_id": 39952992052351924}, {"start": 11, "end": 17, "label": "MISC", "score": 1.0, "_id": 763975142086013}, {"start": 34, "end": 41, "label": "MISC", "score": 1.0, "_id": -7419908469109575363}], "predictions": []}}
-{"text": "Peter Blackburn", "id": "1", "metadata": null, "entities": {"annotations": [{"start": 0, "end": 15, "label": "PER", "score": 1.0, "_id": -2336977660267910410}], "predictions": []}}
-{"text": "BRUSSELS 1996-08-22", "id": "2", "metadata": null, "entities": {"annotations": [{"start": 0, "end": 8, "label": "LOC", "score": 1.0, "_id": 5259325350015426157}], "predictions": []}}
diff --git a/tests/fixtures/dataset_dict/conll2003_extract/validation/documents.jsonl b/tests/fixtures/dataset_dict/conll2003_extract/validation/documents.jsonl
deleted file mode 100644
index 3382b03..0000000
--- a/tests/fixtures/dataset_dict/conll2003_extract/validation/documents.jsonl
+++ /dev/null
@@ -1,3 +0,0 @@
-{"text": "CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .", "id": "0", "metadata": null, "entities": {"annotations": [{"start": 10, "end": 24, "label": "ORG", "score": 1.0, "_id": 5570834040261376103}], "predictions": []}}
-{"text": "LONDON 1996-08-30", "id": "1", "metadata": null, "entities": {"annotations": [{"start": 0, "end": 6, "label": "LOC", "score": 1.0, "_id": 5985981094986103334}], "predictions": []}}
-{"text": "West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset by an innings and 39 runs in two days to take over at the head of the county championship .", "id": "2", "metadata": null, "entities": {"annotations": [{"start": 0, "end": 11, "label": "MISC", "score": 1.0, "_id": -527926364977337414}, {"start": 24, "end": 36, "label": "PER", "score": 1.0, "_id": -5723900879336490249}, {"start": 67, "end": 81, "label": "ORG", "score": 1.0, "_id": -6424668215920385847}, {"start": 87, "end": 95, "label": "ORG", "score": 1.0, "_id": 6014708185904075358}], "predictions": []}}
diff --git a/tests/test_dataset_dict.py b/tests/test_dataset_dict.py
deleted file mode 100644
index 5ee22bc..0000000
--- a/tests/test_dataset_dict.py
+++ /dev/null
@@ -1,437 +0,0 @@
-import logging
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Iterable, Optional, Union
-
-import datasets
-import pytest
-from pytorch_ie import Dataset, IterableDataset
-from pytorch_ie.annotations import LabeledSpan
-from pytorch_ie.core import AnnotationList, Document, annotation_field
-from pytorch_ie.documents import TextBasedDocument
-
-from pie_utils import DatasetDict
-from pie_utils.dataset_dict import get_pie_dataset_type
-from pie_utils.document.processors.common import (
-    EnterDatasetDictMixin,
-    EnterDatasetMixin,
-    ExitDatasetDictMixin,
-    ExitDatasetMixin,
-)
-from tests import FIXTURES_ROOT
-
-logger = logging.getLogger(__name__)
-
-DATA_PATH = FIXTURES_ROOT / "dataset_dict" / "conll2003_extract"
-
-
-@pytest.fixture(scope="module")
-def dataset():
-    return datasets.load_dataset("pie/conll2003")
-
-
-@pytest.mark.skip(reason="don't create fixture data again")
-def test_create_fixture_data():
-    conll2003 = DatasetDict.load_dataset("pie/conll2003")
-    for split in list(conll2003):
-        # restrict all splits to 3 examples
-        conll2003 = conll2003.select(split=split, stop=3)
-    conll2003.to_json(DATA_PATH)
-
-
-@dataclass
-class DocumentWithEntitiesAndRelations(TextBasedDocument):
-    entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
-
-
-@pytest.fixture(scope="module")
-def dataset_dict():
-    return DatasetDict.from_json(
-        data_dir=DATA_PATH, document_type=DocumentWithEntitiesAndRelations
-    )
-
-
-def test_from_json(dataset_dict):
-    assert set(dataset_dict) == {"train", "test", "validation"}
-    assert len(dataset_dict["train"]) == 3
-    assert len(dataset_dict["test"]) == 3
-    assert len(dataset_dict["validation"]) == 3
-
-
-@pytest.fixture(scope="module")
-def iterable_dataset_dict():
-    return DatasetDict.from_json(
-        data_dir=DATA_PATH,
-        document_type=DocumentWithEntitiesAndRelations,
-        streaming=True,
-    )
-
-
-def test_iterable_dataset_dict(iterable_dataset_dict):
-    assert set(iterable_dataset_dict) == {"train", "test", "validation"}
-
-
-def test_to_json_and_back(dataset_dict, tmp_path):
-    path = Path(tmp_path) / "dataset_dict"
-    dataset_dict.to_json(path)
-    dataset_dict_from_json = DatasetDict.from_json(
-        data_dir=path,
-        document_type=dataset_dict.document_type,
-    )
-    assert set(dataset_dict_from_json) == set(dataset_dict)
-    for split in dataset_dict:
-        assert len(dataset_dict_from_json[split]) == len(dataset_dict[split])
-        for doc1, doc2 in zip(dataset_dict_from_json[split], dataset_dict[split]):
-            assert doc1 == doc2
-
-
-def test_document_type_empty_no_splits():
-    with pytest.raises(ValueError) as excinfo:
-        DatasetDict().document_type
-    assert (
-        str(excinfo.value)
-        == "dataset does not contain any splits, cannot determine document type"
-    )
-
-
-def test_document_type_different_types(dataset_dict):
-    # load the example dataset as a different document type
-    dataset_dict_different_type = DatasetDict.from_json(
-        data_dir=DATA_PATH,
-        document_type=TextBasedDocument,
-    )
-    assert dataset_dict_different_type.document_type is TextBasedDocument
-    # create a dataset dict with different document types for train and test splits
-    dataset_dict_different_types = DatasetDict(
-        {
-            "train": dataset_dict["train"],
-            "test": dataset_dict_different_type["test"],
-        }
-    )
-    # accessing the document type should raise an error with the message that starts with
-    # "dataset contains splits with different document types:"
-    with pytest.raises(ValueError) as excinfo:
-        dataset_dict_different_types.document_type
-    assert str(excinfo.value).startswith(
-        "dataset contains splits with different document types:"
-    )
-
-
-def test_dataset_type(dataset_dict):
-    assert dataset_dict.dataset_type is Dataset
-
-
-def test_dataset_type_no_splits():
-    with pytest.raises(ValueError) as excinfo:
-        DatasetDict().dataset_type
-    assert (
-        excinfo.value == "dataset does not contain any splits, cannot determine dataset type"
-    )
-
-
-def test_dataset_type_different_type(dataset_dict, iterable_dataset_dict):
-    dataset_dict_different_type = DatasetDict(
-        {
-            "train": dataset_dict["train"],
-            "test": iterable_dataset_dict["test"],
-        }
-    )
-    with pytest.raises(ValueError) as excinfo:
-        dataset_dict_different_type.dataset_type
-    assert excinfo.value == "dataset contains splits with different dataset types"
-
-
-def test_get_pie_dataset_type():
-    hf_ds = datasets.load_dataset("json", data_dir=DATA_PATH, split="train")
-    assert get_pie_dataset_type(hf_ds) == Dataset
-    hf_ds_iterable = datasets.load_dataset(
-        "json", data_dir=DATA_PATH, split="train", streaming=True
-    )
-    assert get_pie_dataset_type(hf_ds_iterable) == IterableDataset
-    with pytest.raises(ValueError):
-        get_pie_dataset_type("not a dataset")
-
-
-def map_fn(doc):
-    doc.text = doc.text.upper()
-    return doc
-
-
-@pytest.mark.parametrize(
-    "function",
-    [map_fn, "tests.test_dataset_dict.map_fn"],
-)
-def test_map(dataset_dict, function):
-    dataset_dict_mapped = dataset_dict.map(function)
-    for split in dataset_dict:
-        assert len(dataset_dict_mapped[split]) == len(dataset_dict[split])
-        for doc1, doc2 in zip(dataset_dict_mapped[split], dataset_dict[split]):
-            assert doc1.text == doc2.text.upper()
-
-
-def test_map_noop(dataset_dict):
-    dataset_dict_mapped = dataset_dict.map()
-    for split in dataset_dict:
-        assert len(dataset_dict_mapped[split]) == len(dataset_dict[split])
-        for doc1, doc2 in zip(dataset_dict_mapped[split], dataset_dict[split]):
-            assert doc1 == doc2
-
-
-def test_map_with_result_document_type(dataset_dict):
-    dataset_dict_mapped = dataset_dict.map(result_document_type=TextBasedDocument)
-    for split in dataset_dict:
-        assert len(dataset_dict_mapped[split]) == len(dataset_dict[split])
-        for doc1, doc2 in zip(dataset_dict_mapped[split], dataset_dict[split]):
-            assert isinstance(doc1, TextBasedDocument)
-            assert isinstance(doc2, DocumentWithEntitiesAndRelations)
-            assert doc1.text == doc2.text
-
-
-def test_map_with_context_manager(dataset_dict):
-    class DocumentCounter(
-        EnterDatasetMixin, ExitDatasetMixin, EnterDatasetDictMixin, ExitDatasetDictMixin
-    ):
-        def reset_statistics(self):
-            self.number = 0
-
-        def __call__(self, doc):
-            self.number += 1
-            return doc
-
-        def enter_dataset(
-            self, dataset: Union[Dataset, IterableDataset], name: Optional[str] = None
-        ) -> None:
-            self.reset_statistics()
-            self.split = name
-
-        def exit_dataset(
-            self, dataset: Union[Dataset, IterableDataset], name: Optional[str] = None
-        ) -> None:
-            self.all_docs[self.split] = self.number
-
-        def enter_dataset_dict(self, dataset_dict: DatasetDict) -> None:
-            self.all_docs = {}
-            self.split = None
-
-        def exit_dataset_dict(self, dataset_dict: DatasetDict) -> None:
-            logger.info(f"Number of documents per split: {self.all_docs}")
-
-    document_counter = DocumentCounter()
-    # note that we need to disable caching here, otherwise the __call__ method may not be called for any dataset split
-    dataset_dict_mapped = dataset_dict.map(function=document_counter, load_from_cache_file=False)
-    assert document_counter.all_docs == {"train": 3, "test": 3, "validation": 3}
-
-    # the document_counter should not have been modified the dataset
-    assert set(dataset_dict_mapped) == set(dataset_dict)
-    for split in dataset_dict:
-        assert len(dataset_dict_mapped[split]) == len(dataset_dict[split])
-        for doc1, doc2 in zip(dataset_dict_mapped[split], dataset_dict[split]):
-            assert doc1 == doc2
-
-
-def test_select(dataset_dict):
-    # select documents by index
-    dataset_dict_selected = dataset_dict.select(
-        split="train",
-        indices=[0, 2],
-    )
-    assert len(dataset_dict_selected["train"]) == 2
-    assert dataset_dict_selected["train"][0] == dataset_dict["train"][0]
-    assert dataset_dict_selected["train"][1] == dataset_dict["train"][2]
-
-    # select documents by range
-    dataset_dict_selected = dataset_dict.select(
-        split="train",
-        stop=2,
-        start=1,
-        step=1,
-    )
-    assert len(dataset_dict_selected["train"]) == 1
-    assert dataset_dict_selected["train"][0] == dataset_dict["train"][1]
-
-    # calling with no arguments that do result in the creation of indices should return the same dataset,
-    # but will log a warning if other arguments (here "any_arg") are passed
-    dataset_dict_selected = dataset_dict.select(split="train", any_arg="ignored")
-    assert len(dataset_dict_selected["train"]) == len(dataset_dict["train"])
-    assert dataset_dict_selected["train"][0] == dataset_dict["train"][0]
-    assert dataset_dict_selected["train"][1] == dataset_dict["train"][1]
-    assert dataset_dict_selected["train"][2] == dataset_dict["train"][2]
-
-
-def test_rename_splits(dataset_dict):
-    mapping = {
-        "train": "train_renamed",
-        "test": "test_renamed",
-        "validation": "validation_renamed",
-    }
-    dataset_dict_renamed = dataset_dict.rename_splits(mapping)
-    assert set(dataset_dict_renamed) == set(mapping.values())
-    for split in dataset_dict:
-        split_renamed = mapping[split]
-        assert len(dataset_dict_renamed[split_renamed]) == len(dataset_dict[split])
-        for doc1, doc2 in zip(dataset_dict_renamed[split_renamed], dataset_dict[split]):
-            assert doc1 == doc2
-
-
-def test_rename_split_noop(dataset_dict):
-    dataset_dict_renamed = dataset_dict.rename_splits()
-    assert set(dataset_dict_renamed) == set(dataset_dict)
-    for split in dataset_dict:
-        assert len(dataset_dict_renamed[split]) == len(dataset_dict[split])
-        for doc1, doc2 in zip(dataset_dict_renamed[split], dataset_dict[split]):
-            assert doc1 == doc2
-
-
-def assert_doc_lists_equal(docs: Iterable[Document], other_docs: Iterable[Document]):
-    assert all(doc1 == doc2 for doc1, doc2 in zip(docs, other_docs))
-
-
-def test_add_test_split(dataset_dict):
-    dataset_dict_with_test = dataset_dict.add_test_split(
-        source_split="test", target_split="new_test", test_size=1, shuffle=False
-    )
-    assert "new_test" in dataset_dict_with_test
-    assert len(dataset_dict_with_test["new_test"]) + len(dataset_dict_with_test["test"]) == len(
-        dataset_dict["test"]
-    )
-    assert len(dataset_dict_with_test["new_test"]) == 1
-    assert len(dataset_dict_with_test["test"]) == 2
-    assert_doc_lists_equal(dataset_dict_with_test["new_test"], dataset_dict["test"][2:])
-    assert_doc_lists_equal(dataset_dict_with_test["test"], dataset_dict["test"][:2])
-    test_ids = [doc.id for doc in dataset_dict_with_test["test"]]
-    new_test_ids = [doc.id for doc in dataset_dict_with_test["new_test"]]
-    assert set(test_ids).intersection(set(new_test_ids)) == set()
-
-    # remaining splits should be unchanged
-    assert len(dataset_dict_with_test["train"]) == len(dataset_dict["train"])
-    assert len(dataset_dict_with_test["validation"]) == len(dataset_dict["validation"])
-    assert_doc_lists_equal(dataset_dict_with_test["train"], dataset_dict["train"])
-    assert_doc_lists_equal(dataset_dict_with_test["validation"], dataset_dict["validation"])
-
-
-def test_drop_splits(dataset_dict):
-    dataset_dict_dropped = dataset_dict.drop_splits(["train", "validation"])
-    assert set(dataset_dict_dropped) == {"test"}
-    assert len(dataset_dict_dropped["test"]) == len(dataset_dict["test"])
-    assert_doc_lists_equal(dataset_dict_dropped["test"], dataset_dict["test"])
-
-
-def test_concat_splits(dataset_dict):
-    dataset_dict_concatenated = dataset_dict.concat_splits(["train", "validation"], target="train")
-    assert set(dataset_dict_concatenated) == {"test", "train"}
-    assert len(dataset_dict_concatenated["train"]) == len(dataset_dict["train"]) + len(
-        dataset_dict["validation"]
-    )
-    assert_doc_lists_equal(
-        dataset_dict_concatenated["train"],
-        list(dataset_dict["train"]) + list(dataset_dict["validation"]),
-    )
-
-
-def test_concat_splits_no_splits(dataset_dict):
-    with pytest.raises(ValueError) as excinfo:
-        dataset_dict.concat_splits(splits=[], target="train")
-    assert excinfo.value == "please provide at least one split to concatenate"
-
-
-def test_concat_splits_different_dataset_types(dataset_dict, iterable_dataset_dict):
-    dataset_dict_to_concat = DatasetDict(
-        {
-            "train": dataset_dict["train"],
-            "validation": iterable_dataset_dict["validation"],
-        }
-    )
-    with pytest.raises(ValueError) as excinfo:
-        dataset_dict_to_concat.concat_splits(splits=["train", "validation"], target="train")
-    assert excinfo.value.startswith("dataset types of splits to concatenate differ:")
-
-
-def test_filter(dataset_dict):
-    dataset_dict_filtered = dataset_dict.filter(
-        function=lambda doc: len(doc["text"]) > 15,
-        split="train",
-    )
-    assert all(len(doc.text) > 15 for doc in dataset_dict_filtered["train"])
-    assert len(dataset_dict["train"]) == 3
-    assert len(dataset_dict_filtered["train"]) == 2
-    assert dataset_dict_filtered["train"][0] == dataset_dict["train"][0]
-    assert dataset_dict_filtered["train"][1] == dataset_dict["train"][2]
-
-    # remaining splits should be unchanged
-    assert len(dataset_dict_filtered["validation"]) == len(dataset_dict["validation"]) == 3
-    assert len(dataset_dict_filtered["test"]) == len(dataset_dict["test"]) == 3
-    assert_doc_lists_equal(dataset_dict_filtered["validation"], dataset_dict["validation"])
-    assert_doc_lists_equal(dataset_dict_filtered["test"], dataset_dict["test"])
-
-
-def test_filter_iterable(iterable_dataset_dict):
-    dataset_dict_filtered = iterable_dataset_dict.filter(
-        function=lambda doc: len(doc["text"]) > 15,
-        split="train",
-    )
-    docs_train = list(dataset_dict_filtered["train"])
-    assert len(docs_train) == 2
-    assert all(len(doc.text) > 15 for doc in docs_train)
-
-
-def test_filter_unknown_dataset_type():
-    dataset_dict = DatasetDict({"train": "foo"})
-    with pytest.raises(ValueError) as excinfo:
-        dataset_dict.filter(function=lambda doc: True, split="train")
-    assert excinfo.value == "dataset split has unknown type: "
-
-
-def test_filter_noop(dataset_dict):
-    # passing no filter function should be a noop
-    dataset_dict_filtered = dataset_dict.filter(split="train")
-    assert len(dataset_dict_filtered["train"]) == len(dataset_dict["train"]) == 3
-    assert len(dataset_dict_filtered["validation"]) == len(dataset_dict["validation"]) == 3
-    assert len(dataset_dict_filtered["test"]) == len(dataset_dict["test"]) == 3
-    assert_doc_lists_equal(dataset_dict_filtered["train"], dataset_dict["train"])
-    assert_doc_lists_equal(dataset_dict_filtered["validation"], dataset_dict["validation"])
-    assert_doc_lists_equal(dataset_dict_filtered["test"], dataset_dict["test"])
-
-
-@pytest.mark.parametrize(
-    # we can either provide ids or a filter function
-    "ids,filter_function",
-    [
-        (["1", "2"], None),
-        (None, lambda doc: doc["id"] in ["1", "2"]),
-    ],
-)
-def test_move_to_new_split(dataset_dict, ids, filter_function):
-    # move the second and third document from train to new_validation
-    dataset_dict_moved = dataset_dict.move_to_new_split(
-        ids=ids,
-        filter_function=filter_function,
-        source_split="train",
-        target_split="new_validation",
-    )
-    assert len(dataset_dict_moved["train"]) == 1
-    assert len(dataset_dict_moved["new_validation"]) == 2
-    assert_doc_lists_equal(dataset_dict_moved["train"], dataset_dict["train"][:1])
-
-    # the remaining splits should be unchanged
-    assert len(dataset_dict_moved["validation"]) == len(dataset_dict["validation"]) == 3
-    assert len(dataset_dict_moved["test"]) == len(dataset_dict["test"]) == 3
-    assert_doc_lists_equal(dataset_dict_moved["validation"], dataset_dict["validation"])
-    assert_doc_lists_equal(dataset_dict_moved["test"], dataset_dict["test"])
-
-
-def test_move_to_new_split_missing_arguments(dataset_dict):
-    with pytest.raises(ValueError) as excinfo:
-        dataset_dict.move_to_new_split(
-            ids=None,
-            filter_function=None,
-            source_split="train",
-            target_split="new_validation",
-        )
-    assert excinfo.value == "please provide either a list of ids or a filter function"
-
-
-def test_cast_document_type(dataset_dict):
-    dataset_dict_cast = dataset_dict.cast_document_type(TextBasedDocument)
-    assert dataset_dict_cast.document_type == TextBasedDocument
-    for split in dataset_dict_cast:
-        assert all(isinstance(doc, TextBasedDocument) for doc in dataset_dict_cast[split])
diff --git a/tests/test_hydra.py b/tests/test_hydra.py
deleted file mode 100644
index 9fc8047..0000000
--- a/tests/test_hydra.py
+++ /dev/null
@@ -1,62 +0,0 @@
-from importlib import import_module
-
-import pytest
-
-from pie_utils.hydra import InstantiationException, resolve_target
-
-
-def test_resolve_target_string():
-    target_str = "pie_utils.hydra.resolve_target"
-    target = resolve_target(target_str)
-    assert target == resolve_target
-
-
-def test_resolve_target_not_found():
-    with pytest.raises(InstantiationException):
-        resolve_target("does.not.exist", full_key="full_key")
-
-
-def test_resolve_target_empty_path():
-    with pytest.raises(InstantiationException):
-        resolve_target("")
-
-
-def test_resolve_target_empty_part():
-    with pytest.raises(InstantiationException):
-        resolve_target("pie_utils..hydra.resolve_target")
-
-
-def test_resolve_target_from_src():
-    resolve_target("src.pie_utils.hydra.resolve_target")
-
-
-def test_resolve_target_from_src_not_found():
-    with pytest.raises(InstantiationException):
-        resolve_target("tests.fixtures.not_loadable")
-
-
-def test_resolve_target_not_loadable(monkeypatch):
-    # Normally, import_module will raise ModuleNotFoundError, but we want to test the case
-    # in _locate where it raises a different exception.
-    # So we mock the import_module function to raise a different exception on the second call
-    # (the first call is important to succeed because otherwise we just check the first try/except block).
-    class MockImportModule:
-        def __init__(self):
-            self.counter = 0
-
-        def __call__(self, path):
-            if self.counter < 1:
-                self.counter += 1
-                return import_module(path)
-            raise Exception("Custom exception")
-
-    # Apply the monkeypatch to replace import_module with our mock function
-    monkeypatch.setattr("importlib.import_module", MockImportModule())
-
-    with pytest.raises(Exception):
-        resolve_target("src.invalid_attr")
-
-
-def test_resolve_target_not_callable_with_full_key():
-    with pytest.raises(InstantiationException):
-        resolve_target("pie_utils.hydra", full_key="full_key")
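
Migration note: the substantive library change above swaps the local target_layers() helper, which read the relation layer's private _targets and _document attributes and then checked that exactly one target layer exists, for the target_layer attribute that CandidateRelationAdder.__call__ now reads directly; this is presumably also why the pytorch-ie floor in setup.py moves from 0.16.0 to 0.17.0. A minimal sketch of the new access pattern, assuming a hypothetical document class (only relations.target_layer itself is taken from the diff above):

    # Hypothetical document type for illustration; not part of this diff.
    from dataclasses import dataclass

    from pytorch_ie.annotations import BinaryRelation, LabeledSpan
    from pytorch_ie.core import AnnotationList, annotation_field
    from pytorch_ie.documents import TextBasedDocument


    @dataclass
    class ExampleDocument(TextBasedDocument):
        entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
        relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")


    doc = ExampleDocument(text="Jane works at Acme.")
    doc.entities.append(LabeledSpan(start=0, end=4, label="PER"))
    doc.entities.append(LabeledSpan(start=14, end=18, label="ORG"))
    doc.relations.append(
        BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="works_at")
    )

    # Old: target_layers(layer=doc.relations) returned a name->layer dict that the
    # caller had to check for exactly one entry. New (pytorch-ie>=0.17.0): the single
    # target layer is available directly on the relation layer.
    entity_layer = doc.relations.target_layer
    assert entity_layer is doc.entities

With that attribute in place, the dict-plus-length-check in __call__ collapses to a single lookup, and the single-target check no longer needs to live in this processor.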