Skip to content

Commit 3f6fab0

Browse files
authored
register all models @pie_core.Model (#503)
This PR registers all models `@pie_core.Model` instead of `@PyTorchIEModel` to make them work with `pie_core.AutoAnnotationPipeline.from_config` in the future. Additional changes: - remove `AutoPyTorchIEModel` (just use `AutoModel` instead) - re-export `AutoModel` from `pie_core`. NOTE: Previously `AutoPyTorchIEModel` was exported as `AutoModel`! But if used correctly, it should not make a difference. - register `PyTorchIEPipeline` at name `"pytorch-ie"` - remove backwards compatibility: `Pipeline` and `AutoPipeline`(use `PyTorchIEPipeline` in either case) - adjust the `README.md` Note: It is not yet possible to load an "old" pipeline (e.g. [pie/example-ner-spanclf-conll03](https://huggingface.co/pie/example-ner-spanclf-conll03)) via `AutoAnnotationPipeline` since its config does not has a `pipeline_type` key. In the future, it should be possible to provide that as argument in the following manner, but **this still needs a minor adjustment in `pie_core.Auto.from_config` (TODO: reference issue/PR in pie-core)**: ``` pipeline = AutoAnnotationPipeline.from_pretrained("pie/example-ner-spanclf-conll03", pipeline_type="pytorch-ie") ```
1 parent 06fdad2 commit 3f6fab0

22 files changed

+59
-99
lines changed

README.md

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,8 @@ for details.
318318
```python
319319
from dataclasses import dataclass
320320

321+
from pytorch_ie import PyTorchIEPipeline
321322
from pytorch_ie.annotations import LabeledSpan
322-
from pytorch_ie.auto import AutoPipeline
323323
from pytorch_ie.core import AnnotationLayer, annotation_field
324324
from pytorch_ie.documents import TextDocument
325325

@@ -334,7 +334,7 @@ document = ExampleDocument(
334334
)
335335

336336
# see below for the long version
337-
ner_pipeline = AutoPipeline.from_pretrained("pie/example-ner-spanclf-conll03", device=-1, num_workers=0)
337+
ner_pipeline = PyTorchIEPipeline.from_pretrained("pie/example-ner-spanclf-conll03", device=-1, num_workers=0)
338338

339339
ner_pipeline(document)
340340

@@ -349,7 +349,7 @@ for entity in document.entities.predictions:
349349

350350
<details>
351351
<summary>
352-
To create the same pipeline as above without `AutoPipeline`
352+
Under the hood, the following happens when calling `PyTorchIEPipeline.from_pretrained`
353353
</summary>
354354

355355
```python
@@ -364,23 +364,6 @@ ner_pipeline = PyTorchIEPipeline(model=ner_model, taskmodule=ner_taskmodule, dev
364364

365365
</details>
366366

367-
<details>
368-
<summary>
369-
Or, without `Auto` classes at all
370-
</summary>
371-
372-
```python
373-
from pytorch_ie.pipeline import PyTorchIEPipeline
374-
from pytorch_ie.models import TransformerSpanClassificationModel
375-
from pytorch_ie.taskmodules import TransformerSpanClassificationTaskModule
376-
377-
model_name_or_path = "pie/example-ner-spanclf-conll03"
378-
ner_taskmodule = TransformerSpanClassificationTaskModule.from_pretrained(model_name_or_path)
379-
ner_model = TransformerSpanClassificationModel.from_pretrained(model_name_or_path)
380-
ner_pipeline = PyTorchIEPipeline(model=ner_model, taskmodule=ner_taskmodule, device=-1, num_workers=0)
381-
```
382-
383-
</details>
384367
<details>
385368
<summary>
386369

@@ -391,8 +374,8 @@ ner_pipeline = PyTorchIEPipeline(model=ner_model, taskmodule=ner_taskmodule, dev
391374
```python
392375
from dataclasses import dataclass
393376

377+
from pytorch_ie import PyTorchIEPipeline
394378
from pytorch_ie.annotations import BinaryRelation, LabeledSpan
395-
from pytorch_ie.auto import AutoPipeline
396379
from pytorch_ie.core import AnnotationLayer, annotation_field
397380
from pytorch_ie.documents import TextDocument
398381

@@ -407,7 +390,7 @@ document = ExampleDocument(
407390
"“Making a super tasty alt-chicken wing is only half of it,” said Po Bronson, general partner at SOSV and managing director of IndieBio."
408391
)
409392

410-
re_pipeline = AutoPipeline.from_pretrained("pie/example-re-textclf-tacred", device=-1, num_workers=0)
393+
re_pipeline = PyTorchIEPipeline.from_pretrained("pie/example-re-textclf-tacred", device=-1, num_workers=0)
411394

412395
for start, end, label in [(65, 75, "PER"), (96, 100, "ORG"), (126, 134, "ORG")]:
413396
document.entities.append(LabeledSpan(start=start, end=end, label=label))

examples/predict/ner_span_classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
from pie_core import AnnotationLayer, annotation_field
44

5+
from pytorch_ie import PyTorchIEPipeline
56
from pytorch_ie.annotations import LabeledSpan
67
from pytorch_ie.documents import TextDocument
78
from pytorch_ie.models import TransformerSpanClassificationModel
8-
from pytorch_ie.pipeline import Pipeline
99
from pytorch_ie.taskmodules import TransformerSpanClassificationTaskModule
1010

1111

@@ -19,7 +19,7 @@ def main():
1919
ner_taskmodule = TransformerSpanClassificationTaskModule.from_pretrained(model_name_or_path)
2020
ner_model = TransformerSpanClassificationModel.from_pretrained(model_name_or_path)
2121

22-
ner_pipeline = Pipeline(model=ner_model, taskmodule=ner_taskmodule, device=-1)
22+
ner_pipeline = PyTorchIEPipeline(model=ner_model, taskmodule=ner_taskmodule, device=-1)
2323

2424
document = ExampleDocument(
2525
"“Making a super tasty alt-chicken wing is only half of it,” said Po Bronson, general partner at SOSV and managing director of IndieBio."

examples/predict/re_generative.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
from pie_core import AnnotationLayer, annotation_field
44

5+
from pytorch_ie import PyTorchIEPipeline
56
from pytorch_ie.annotations import BinaryRelation, LabeledSpan
67
from pytorch_ie.documents import TextDocument
78
from pytorch_ie.models import TransformerSeq2SeqModel
8-
from pytorch_ie.pipeline import Pipeline
99
from pytorch_ie.taskmodules import TransformerSeq2SeqTaskModule
1010

1111

@@ -30,7 +30,7 @@ def main():
3030
model_name_or_path=model_name_or_path,
3131
)
3232

33-
pipeline = Pipeline(model=model, taskmodule=taskmodule, device=-1)
33+
pipeline = PyTorchIEPipeline(model=model, taskmodule=taskmodule, device=-1)
3434

3535
document = ExampleDocument(
3636
"“Making a super tasty alt-chicken wing is only half of it,” said Po Bronson, general partner at SOSV and managing director of IndieBio."

examples/predict/re_text_classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
from pie_core import AnnotationLayer, annotation_field
44

5+
from pytorch_ie import PyTorchIEPipeline
56
from pytorch_ie.annotations import BinaryRelation, LabeledSpan
67
from pytorch_ie.documents import TextDocument
78
from pytorch_ie.models import TransformerTextClassificationModel
8-
from pytorch_ie.pipeline import Pipeline
99
from pytorch_ie.taskmodules import TransformerRETextClassificationTaskModule
1010

1111

@@ -22,7 +22,7 @@ def main():
2222
)
2323
re_model = TransformerTextClassificationModel.from_pretrained(model_name_or_path)
2424

25-
re_pipeline = Pipeline(model=re_model, taskmodule=re_taskmodule, device=-1)
25+
re_pipeline = PyTorchIEPipeline(model=re_model, taskmodule=re_taskmodule, device=-1)
2626

2727
document = ExampleDocument(
2828
"“Making a super tasty alt-chicken wing is only half of it,” said Po Bronson, general partner at SOSV and managing director of IndieBio."

src/pytorch_ie/__init__.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
# flake8: noqa
22

3-
from pytorch_ie.auto import AutoModel, AutoPipeline, AutoTaskModule
3+
from pytorch_ie.auto import AutoModel, AutoTaskModule
44
from pytorch_ie.core import *
55
from pytorch_ie.datamodule import PieDataModule
66
from pytorch_ie.dataset import IterableTaskEncodingDataset, TaskEncodingDataset
77
from pytorch_ie.pipeline import PyTorchIEPipeline
8-
9-
# kept for backward compatibility
10-
Pipeline = PyTorchIEPipeline

src/pytorch_ie/auto.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,2 @@
11
# kept for backward compatibility
2-
from pie_core import AutoTaskModule
3-
4-
# kept for backward compatibility
5-
from pytorch_ie.model import AutoPyTorchIEModel as AutoModel
6-
from pytorch_ie.pipeline import PyTorchIEPipeline as AutoPipeline
2+
from pie_core import AutoModel, AutoTaskModule

src/pytorch_ie/model.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,3 @@ def predict(self, inputs: Any, **kwargs) -> Any:
3838
outputs = self(inputs, **kwargs)
3939
decoded_outputs = self.decode(inputs=inputs, outputs=outputs)
4040
return decoded_outputs
41-
42-
43-
# TODO: remove this class when all models are registered with @Model.register()
44-
# also see notes in PyTorchIEPipeline
45-
class AutoPyTorchIEModel(Model, Auto[PyTorchIEModel]):
46-
47-
BASE_CLASS = PyTorchIEModel

src/pytorch_ie/models/sequence_classification_with_pooler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
)
1515

1616
import torch
17+
from pie_core import Model
1718
from torch import FloatTensor, LongTensor, nn
1819
from torch.nn import Parameter
1920
from torch.optim import AdamW
2021
from transformers import AutoConfig, AutoModel, PreTrainedModel, get_linear_schedule_with_warmup
2122
from transformers.modeling_outputs import SequenceClassifierOutput
2223
from typing_extensions import TypeAlias
2324

24-
from pytorch_ie import PyTorchIEModel
2525
from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
2626

2727
from .common import ModelWithBoilerplate
@@ -236,7 +236,7 @@ def configure_optimizers(self):
236236
return optimizer
237237

238238

239-
@PyTorchIEModel.register()
239+
@Model.register()
240240
class SequenceClassificationModelWithPooler(
241241
SequenceClassificationModelWithPoolerBase,
242242
RequiresNumClasses,
@@ -286,7 +286,7 @@ def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
286286
return {"labels": labels, "probabilities": probabilities}
287287

288288

289-
@PyTorchIEModel.register()
289+
@Model.register()
290290
class SequencePairSimilarityModelWithPooler(
291291
SequenceClassificationModelWithPoolerBase,
292292
):

src/pytorch_ie/models/simple_generative.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any, Dict, Optional, Tuple, Type, Union
44

55
import torch
6+
from pie_core import Model
67
from pie_core.utils.hydra import resolve_type
78
from pytorch_lightning.utilities.types import OptimizerLRScheduler
89
from torch import FloatTensor, LongTensor
@@ -11,8 +12,6 @@
1112
from transformers.modeling_outputs import Seq2SeqLMOutput
1213
from typing_extensions import TypeAlias
1314

14-
from pytorch_ie import PyTorchIEModel
15-
1615
from .common import ModelWithBoilerplate
1716

1817
logger = logging.getLogger(__name__)
@@ -26,7 +25,7 @@
2625
StepOutputType: TypeAlias = FloatTensor
2726

2827

29-
@PyTorchIEModel.register()
28+
@Model.register()
3029
class SimpleGenerativeModel(
3130
ModelWithBoilerplate[InputType, OutputType, TargetType, StepOutputType],
3231
):

src/pytorch_ie/models/simple_sequence_classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Iterator, MutableMapping, Optional, Tuple, Union
33

44
import torch.nn
5+
from pie_core import Model
56
from torch import FloatTensor, LongTensor
67
from torch.nn import Parameter
78
from torch.optim import AdamW
@@ -13,7 +14,6 @@
1314
from transformers.modeling_outputs import SequenceClassifierOutput
1415
from typing_extensions import TypeAlias
1516

16-
from pytorch_ie import PyTorchIEModel
1717
from pytorch_ie.models.common import ModelWithBoilerplate
1818
from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
1919

@@ -29,7 +29,7 @@
2929
logger = logging.getLogger(__name__)
3030

3131

32-
@PyTorchIEModel.register()
32+
@Model.register()
3333
class SimpleSequenceClassificationModel(
3434
ModelWithBoilerplate[InputType, OutputType, TargetType, StepOutputType],
3535
RequiresModelNameOrPath,

0 commit comments

Comments
 (0)