
Commit 51171a2

use AutoAnnotationPipeline (#216)
* use pie-core 0.3.1
* switch back to using `AutoAnnotationPipeline` instead of `PyTorchIEPipeline`; add comments to the pipeline configs on adding `pipeline_type: pytorch-ie` for old models
* fix tests by setting `pipeline.pipeline_type: pytorch-ie` in the train.yaml config
* save the pipeline after training (instead of the individual model and taskmodule)
* during inference, raise an exception when trying to load a model checkpoint into a pipeline that is not a PyTorchIEPipeline
1 parent b2aec47 commit 51171a2

File tree: 8 files changed (+35, -23 lines)
Lines changed: 3 additions & 1 deletion
@@ -1,3 +1,5 @@
-_target_: pytorch_ie.PyTorchIEPipeline.from_pretrained
+_target_: pie_core.AutoAnnotationPipeline.from_pretrained
 pretrained_model_name_or_path: ???
 show_progress_bar: true
+# uncomment for "old" PyTorch-IE models that do not have a pipeline_type key in their config.json
+# pipeline_type: pytorch-ie
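For context, a minimal sketch of how this Hydra config resolves at runtime; the call mirrors the instantiation in src/train.py below, and the model id is an example borrowed from configs/predict.yaml:

import hydra
from omegaconf import OmegaConf

# A sketch, not repo code: build the config above and resolve it the same way
# src/train.py does. The model id is an example from configs/predict.yaml.
cfg = OmegaConf.create(
    {
        "_target_": "pie_core.AutoAnnotationPipeline.from_pretrained",
        "pretrained_model_name_or_path": "pie/example-ner-spanclf-conll03",
        "show_progress_bar": True,
        # only needed for "old" models without a pipeline_type key in config.json:
        # "pipeline_type": "pytorch-ie",
    }
)
pipeline = hydra.utils.instantiate(cfg, _convert_="partial")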

configs/pipeline/ner_re_pipeline.yaml

Lines changed: 4 additions & 0 deletions
@@ -10,5 +10,9 @@ show_progress_bar: true
 device: -1
 ner_pipeline:
   batch_size: 1
+  # uncomment for "old" PyTorch-IE models that do not have a pipeline_type key in their config.json
+  # pipeline_type: pytorch-ie
 re_pipeline:
   batch_size: 1
+  # uncomment for "old" PyTorch-IE models that do not have a pipeline_type key in their config.json
+  # pipeline_type: pytorch-ie
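The ner_pipeline and re_pipeline blocks above become keyword arguments for the hub loader: as the change to src/pipeline/ner_re_pipeline.py below shows, they are forwarded via **self.processor_kwargs.get(...). A hedged sketch of the equivalent direct call, with a placeholder model path:

from pie_core import AutoAnnotationPipeline

# Equivalent direct call for the ner_pipeline entry above; "path/to/ner-model"
# is a placeholder, and pipeline_type is only required for old checkpoints.
ner_pipeline = AutoAnnotationPipeline.from_pretrained(
    "path/to/ner-model",
    batch_size=1,
    pipeline_type="pytorch-ie",
)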

configs/predict.yaml

Lines changed: 4 additions & 0 deletions
@@ -24,6 +24,10 @@ name: "default"
 # or the url to huggingface hub where the taskmodule and model was pushed to.
 # It is used in the pipeline config.
 model_name_or_path: pie/example-ner-spanclf-conll03
+# required for "old" PyTorch-IE models that do not have a pipeline_type key in their config.json.
+# Saving a model with AnnotationPipeline.push_to_hub will automatically add this key to the config.json
+pipeline:
+  pipeline_type: pytorch-ie

 # to override model weights with content of a checkpoint
 ckpt_path: null
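Since newer saves write the pipeline_type key into config.json automatically, one way to tell whether a local checkpoint still needs this override is to inspect its config. A hypothetical helper sketch, not part of the repo:

import json
from pathlib import Path

# Hypothetical helper: the override above is only needed when the checkpoint's
# config.json lacks the pipeline_type key that newer saves add automatically.
def needs_pipeline_type_override(model_dir: str) -> bool:
    config = json.loads((Path(model_dir) / "config.json").read_text())
    return "pipeline_type" not in config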

poetry.lock

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ dependencies = [
   # --------- python-ie --------- #
   "pytorch-ie (>=0.33.0,<0.34.0)",
   "pie-datasets (>=0.11.0,<0.12.0)",
+  "pie-core (>=0.3.1,<0.4.0)", # to use AutoAnnotationPipeline with old models, see https://github.com/ArneBinder/pie-core/pull/95

   # ------- reprocessing -------- #
   "nltk (>=3.9.1,<4.0.0)", # sentence splitter (just for drugprot.yaml experiment which dry-runs in slow tests, remove if not needed)

src/pipeline/ner_re_pipeline.py

Lines changed: 3 additions & 4 deletions
@@ -4,9 +4,8 @@
 from functools import partial
 from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, TypeVar, Union

+from pie_core import AutoAnnotationPipeline, Document, WithDocumentTypeMixin
 from pie_core.utils.hydra import resolve_type
-from pytorch_ie import PyTorchIEPipeline, WithDocumentTypeMixin
-from pytorch_ie.core import Document

 logger = logging.getLogger(__name__)

@@ -166,7 +165,7 @@ def __call__(self, documents: Sequence[Document], inplace: bool = False) -> Sequ
                 layer_names=[self.entity_layer, self.relation_layer],
                 **self.processor_kwargs.get("clear_annotations", {}),
             ),
-            "ner_pipeline": PyTorchIEPipeline.from_pretrained(
+            "ner_pipeline": AutoAnnotationPipeline.from_pretrained(
                 self.ner_model_path, **self.processor_kwargs.get("ner_pipeline", {})
             ),
             "use_predicted_entities": partial(
@@ -181,7 +180,7 @@ def __call__(self, documents: Sequence[Document], inplace: bool = False) -> Sequ
             # **self.processor_kwargs.get("create_candidate_relations", {})
             # ),
             # ),
-            "re_pipeline": PyTorchIEPipeline.from_pretrained(
+            "re_pipeline": AutoAnnotationPipeline.from_pretrained(
                 self.re_model_path, **self.processor_kwargs.get("re_pipeline", {})
             ),
             # otherwise we can not move the entities back to predictions
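A hedged usage sketch of the two-stage pipeline after this change. The constructor argument names are inferred from the self.* fields in the hunks above and the import path from the file location, so both are assumptions; paths and layer names are placeholders:

from pie_core import Document
from src.pipeline.ner_re_pipeline import NerRePipeline

documents: list[Document] = []  # placeholder: documents to annotate

# Argument names inferred from the fields used above (ner_model_path,
# re_model_path, entity_layer, relation_layer, processor_kwargs); per-stage
# kwargs such as {"ner_pipeline": {"pipeline_type": "pytorch-ie"}} reach
# AutoAnnotationPipeline.from_pretrained via processor_kwargs.
pipeline = NerRePipeline(
    ner_model_path="path/to/ner-model",
    re_model_path="path/to/re-model",
    entity_layer="entities",
    relation_layer="relations",
    processor_kwargs={"ner_pipeline": {"batch_size": 1}},
)
predicted_docs = pipeline(documents, inplace=False)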

src/predict.py

Lines changed: 5 additions & 0 deletions
@@ -41,6 +41,7 @@
 from omegaconf import DictConfig, OmegaConf
 from pie_core import AnnotationPipeline
 from pie_datasets import DatasetDict
+from pytorch_ie import PyTorchIEPipeline
 from pytorch_ie.models import *  # noqa: F403
 from pytorch_ie.taskmodules import *  # noqa: F403

@@ -84,6 +85,10 @@ def predict(cfg: DictConfig) -> Tuple[dict, dict]:
     # However, ckpt_path can be used to load different weights from any checkpoint.
     if cfg.ckpt_path is not None:
         log.info(f"Loading model weights from checkpoint: {cfg.ckpt_path}")
+        if not isinstance(pipeline, PyTorchIEPipeline):
+            raise ValueError(
+                "The pipeline has to be of type PyTorchIEPipeline to load a checkpoint."
+            )
         pipeline.model = (
             type(pipeline.model)
             .load_from_checkpoint(checkpoint_path=cfg.ckpt_path)
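The guard exists because only a PyTorchIEPipeline wraps a Lightning model whose class provides load_from_checkpoint; for other AnnotationPipeline types there is no model attribute to override, so the commit raises instead. A condensed standalone sketch of the same logic (hypothetical helper, mirroring the hunk above):

from pie_core import AnnotationPipeline
from pytorch_ie import PyTorchIEPipeline

def override_model_weights(pipeline: AnnotationPipeline, ckpt_path: str) -> None:
    # Refuse checkpoints for pipelines that are not PyTorch-IE pipelines,
    # then rebuild the wrapped Lightning model from the checkpoint.
    if not isinstance(pipeline, PyTorchIEPipeline):
        raise ValueError(
            "The pipeline has to be of type PyTorchIEPipeline to load a checkpoint."
        )
    pipeline.model = type(pipeline.model).load_from_checkpoint(checkpoint_path=ckpt_path)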

src/train.py

Lines changed: 10 additions & 13 deletions
@@ -42,7 +42,7 @@
 from pie_core import AnnotationPipeline, AutoModel, TaskModule
 from pie_core.utils.dictionary import flatten_dict_s
 from pie_datasets import DatasetDict
-from pytorch_ie import PieDataModule, PyTorchIEModel
+from pytorch_ie import PieDataModule, PyTorchIEModel, PyTorchIEPipeline
 from pytorch_ie.models import *  # noqa: F403
 from pytorch_ie.models.interface import (
     RequiresModelNameOrPath,
@@ -197,14 +197,6 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
     log.info("Logging hyperparameters!")
     utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg)

-    if cfg.paths.model_save_dir is not None:
-        log.info(f"Save taskmodule to {cfg.paths.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
-        taskmodule.save_pretrained(
-            save_directory=cfg.paths.model_save_dir, push_to_hub=cfg.push_to_hub
-        )
-    else:
-        log.warning("the taskmodule is not saved because no save_dir is specified")
-
     if cfg.get("train"):
         # Set model in training mode (since pytorch-lightning 2.2.0 the model is not set
         # to train mode automatically in trainer.fit). To just partly train the model
@@ -225,19 +217,25 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
             checkpoint_dir=trainer.checkpoint_callback.dirpath,
         )

+    pipeline: Optional[AnnotationPipeline] = None
     if not cfg.trainer.get("fast_dev_run"):
         if cfg.paths.model_save_dir is not None:
             if best_ckpt_path == "":
                 log.warning("Best ckpt not found! Using current weights for saving...")
             else:
                 model = type(model).load_from_checkpoint(best_ckpt_path)

-            log.info(f"Save model to {cfg.paths.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
-            model.save_pretrained(
+            log.info(
+                f"Save pipeline (model + taskmodule) to {cfg.paths.model_save_dir} [push_to_hub={cfg.push_to_hub}]"
+            )
+            pipeline = PyTorchIEPipeline(model=model, taskmodule=taskmodule)
+            pipeline.save_pretrained(
                 save_directory=cfg.paths.model_save_dir, push_to_hub=cfg.push_to_hub
             )
         else:
-            log.warning("the model is not saved because no save_dir is specified")
+            log.warning(
+                "the pipeline (model + taskmodule) is not saved because no save_dir is specified"
+            )

     if cfg.get("validate"):
         log.info("Starting validation!")
@@ -271,7 +269,6 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
     # This can be overridden by the `predict_split` config parameter.
     split = cfg.get("predict_split", datamodule.test_split)
     # Init the inference pipeline
-    pipeline: Optional[AnnotationPipeline] = None
     if cfg.get("pipeline") and cfg.pipeline.get("_target_"):
         log.info(f"Instantiating inference pipeline <{cfg.pipeline._target_}>")
         pipeline = hydra.utils.instantiate(cfg.pipeline, _convert_="partial")
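Taken together, the training-side change makes saving and loading symmetric: one pipeline artifact is written, and AutoAnnotationPipeline can restore it without the manual pipeline_type hint, since the new save path writes that key into config.json. A hedged round-trip sketch; the directory is a placeholder, and model and taskmodule are assumed to come from a training run as above:

from pie_core import AutoAnnotationPipeline
from pytorch_ie import PyTorchIEPipeline

# Save side (as in src/train.py above): bundle model + taskmodule and persist
# them as a single artifact; "models/my-run" is a placeholder directory.
pipeline = PyTorchIEPipeline(model=model, taskmodule=taskmodule)
pipeline.save_pretrained(save_directory="models/my-run", push_to_hub=False)

# Load side (as in configs/pipeline/*.yaml above): no pipeline_type needed,
# because the saved config.json now carries the key.
restored = AutoAnnotationPipeline.from_pretrained("models/my-run")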
