Skip to content

Commit 6f520f1

Browse files
committed
remove the tensorflow code and simplify the repo to support torch only
1 parent 1f0f3e3 commit 6f520f1

30 files changed

+1070
-3897
lines changed

README.md

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@
5757

5858
- **Configurable and customizable**: models are modularized and configurable,with abstract classes to support developing customized
5959
TPP models.
60-
- **Compatible with both Tensorflow and PyTorch framework**: `EasyTPP` implements two equivalent sets of models, which can
61-
be run under Tensorflow (both Tensorflow 1.13.1 and Tensorflow 2.0) and PyTorch 1.7.0+ respectively. While the PyTorch models are more popular among researchers, the compatibility with Tensorflow is important for industrial practitioners.
60+
- **PyTorch-based implementation**: `EasyTPP` implements state-of-the-art TPP models using PyTorch 1.7.0+, providing a clean and modern deep learning framework.
6261
- **Reproducible**: all the benchmarks can be easily reproduced.
6362
- **Hyper-parameter optimization**: a pipeline of [optuna](https://github.com/optuna/optuna)-based HPO is provided.
6463

@@ -70,14 +69,14 @@ We provide reference implementations of various state-of-the-art TPP papers:
7069

7170
| No | Publication | Model | Paper | Implementation |
7271
|:---:|:-----------:|:-------------:|:-----------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------|
73-
| 1 | KDD'16 | RMTPP | [Recurrent Marked Temporal Point Processes: Embedding Event History to Vector](https://www.kdd.org/kdd2016/papers/files/rpp1081-duA.pdf) | [Tensorflow](easy_tpp/model/tf_model/tf_rmtpp.py)<br/>[Torch](easy_tpp/model/torch_model/torch_rmtpp.py) |
74-
| 2 | NeurIPS'17 | NHP | [The Neural Hawkes Process: A Neurally Self-Modulating Multivariate Point Process](https://arxiv.org/abs/1612.09328) | [Tensorflow](easy_tpp/model/tf_model/tf_nhp.py)<br/>[Torch](easy_tpp/model/torch_model/torch_nhp.py) |
75-
| 3 | NeurIPS'19 | FullyNN | [Fully Neural Network based Model for General Temporal Point Processes](https://arxiv.org/abs/1905.09690) | [Tensorflow](easy_tpp/model/tf_model/tf_fullnn.py)<br/>[Torch](easy_tpp/model/torch_model/torch_fullynn.py) |
76-
| 4 | ICML'20 | SAHP | [Self-Attentive Hawkes process](https://arxiv.org/abs/1907.07561) | [Tensorflow](easy_tpp/model/tf_model/tf_sahp.py)<br/>[Torch](easy_tpp/model/torch_model/torch_sahp.py) |
77-
| 5 | ICML'20 | THP | [Transformer Hawkes process](https://arxiv.org/abs/2002.09291) | [Tensorflow](easy_tpp/model/tf_model/tf_thp.py)<br/>[Torch](easy_tpp/model/torch_model/torch_thp.py) |
78-
| 6 | ICLR'20 | IntensityFree | [Intensity-Free Learning of Temporal Point Processes](https://arxiv.org/abs/1909.12127) | [Tensorflow](easy_tpp/model/tf_model/tf_intensity_free.py)<br/>[Torch](easy_tpp/model/torch_model/torch_intensity_free.py) |
79-
| 7 | ICLR'21 | ODETPP | [Neural Spatio-Temporal Point Processes (simplified)](https://arxiv.org/abs/2011.04583) | [Tensorflow](easy_tpp/model/tf_model/tf_ode_tpp.py)<br/>[Torch](easy_tpp/model/torch_model/torch_ode_tpp.py) |
80-
| 8 | ICLR'22 | AttNHP | [Transformer Embeddings of Irregularly Spaced Events and Their Participants](https://arxiv.org/abs/2201.00044) | [Tensorflow](easy_tpp/model/tf_model/tf_attnhp.py)<br/>[Torch](easy_tpp/model/torch_model/torch_attnhp.py) |
72+
| 1 | KDD'16 | RMTPP | [Recurrent Marked Temporal Point Processes: Embedding Event History to Vector](https://www.kdd.org/kdd2016/papers/files/rpp1081-duA.pdf) | [PyTorch](easy_tpp/model/torch_model/torch_rmtpp.py) |
73+
| 2 | NeurIPS'17 | NHP | [The Neural Hawkes Process: A Neurally Self-Modulating Multivariate Point Process](https://arxiv.org/abs/1612.09328) | [PyTorch](easy_tpp/model/torch_model/torch_nhp.py) |
74+
| 3 | NeurIPS'19 | FullyNN | [Fully Neural Network based Model for General Temporal Point Processes](https://arxiv.org/abs/1905.09690) | [PyTorch](easy_tpp/model/torch_model/torch_fullynn.py) |
75+
| 4 | ICML'20 | SAHP | [Self-Attentive Hawkes process](https://arxiv.org/abs/1907.07561) | [PyTorch](easy_tpp/model/torch_model/torch_sahp.py) |
76+
| 5 | ICML'20 | THP | [Transformer Hawkes process](https://arxiv.org/abs/2002.09291) | [PyTorch](easy_tpp/model/torch_model/torch_thp.py) |
77+
| 6 | ICLR'20 | IntensityFree | [Intensity-Free Learning of Temporal Point Processes](https://arxiv.org/abs/1909.12127) | [PyTorch](easy_tpp/model/torch_model/torch_intensity_free.py) |
78+
| 7 | ICLR'21 | ODETPP | [Neural Spatio-Temporal Point Processes (simplified)](https://arxiv.org/abs/2011.04583) | [PyTorch](easy_tpp/model/torch_model/torch_ode_tpp.py) |
79+
| 8 | ICLR'22 | AttNHP | [Transformer Embeddings of Irregularly Spaced Events and Their Participants](https://arxiv.org/abs/2201.00044) | [PyTorch](easy_tpp/model/torch_model/torch_attnhp.py) |
8180

8281

8382

easy_tpp/config_factory/model_config.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,12 +144,9 @@ def __init__(self, **kwargs):
144144
def set_backend(backend):
145145
if backend.lower() in ['torch', 'pytorch']:
146146
return Backend.Torch
147-
elif backend.lower() in ['tf', 'tensorflow']:
148-
return Backend.TF
149147
else:
150148
raise ValueError(
151-
f"Backend should be selected between 'torch or pytorch' and 'tf or tensorflow', "
152-
f"current value: {backend}"
149+
f"Backend should be 'torch' or 'pytorch', current value: {backend}"
153150
)
154151

155152
def get_yaml_config(self):

easy_tpp/config_factory/runner_config.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from easy_tpp.config_factory.data_config import DataConfig
66
from easy_tpp.config_factory.model_config import TrainerConfig, ModelConfig, BaseConfig
77
from easy_tpp.utils import create_folder, logger, get_unique_id, get_stage, RunnerPhase, \
8-
MetricsHelper, DefaultRunnerConfig, py_assert, is_torch_available, is_tf_available, is_tf_gpu_available, \
8+
MetricsHelper, DefaultRunnerConfig, py_assert, is_torch_available, \
99
is_torch_gpu_available
1010
from easy_tpp.utils.const import Backend
1111

@@ -119,29 +119,18 @@ def update_config(self):
119119
model_id = self.base_config.model_id
120120
self.model_config.model_id = model_id
121121

122-
if self.base_config.model_id == 'ODETPP' and self.base_config.backend == Backend.TF:
123-
py_assert(self.data_config.data_specs.padding_strategy == 'max_length',
124-
ValueError,
125-
'For ODETPP in TensorFlow, we must pad all sequence to '
126-
'the same length (max len of the sequences)!')
127-
128122
run = current_stage
129123
use_torch = self.base_config.backend == Backend.Torch
130124
device = 'GPU' if self.trainer_config.gpu >= 0 else 'CPU'
131125

132-
py_assert(is_torch_available() if use_torch else is_tf_available(), ValueError,
133-
f'Backend {self.base_config.backend} is not supported in the current environment yet !')
126+
py_assert(is_torch_available(), ValueError,
127+
f'PyTorch is not available in the current environment!')
134128

135129
if use_torch and device == 'GPU':
136130
py_assert(is_torch_gpu_available(),
137131
ValueError,
138132
f'Torch cuda is not supported in the current environment yet!')
139133

140-
if not use_torch and device == 'GPU':
141-
py_assert(is_tf_gpu_available(),
142-
ValueError,
143-
f'Tensorflow GPU is not supported in the current environment yet!')
144-
145134
critical_msg = '{run} model {model_name} using {device} ' \
146135
'with {tf_torch} backend'.format(run=run,
147136
model_name=model_id,

easy_tpp/model/__init__.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,6 @@
99
from easy_tpp.model.torch_model.torch_sahp import SAHP as TorchSAHP
1010
from easy_tpp.model.torch_model.torch_thp import THP as TorchTHP
1111

12-
# by default, we use torch and do not install tf, therefore we ignore the import error
13-
try:
14-
from easy_tpp.model.tf_model.tf_basemodel import TfBaseModel
15-
from easy_tpp.model.tf_model.tf_nhp import NHP as TfNHP
16-
from easy_tpp.model.tf_model.tf_ode_tpp import ODETPP as TfODETPP
17-
from easy_tpp.model.tf_model.tf_thp import THP as TfTHP
18-
from easy_tpp.model.tf_model.tf_sahp import SAHP as TfSAHP
19-
from easy_tpp.model.tf_model.tf_rmtpp import RMTPP as TfRMTPP
20-
from easy_tpp.model.tf_model.tf_attnhp import AttNHP as TfAttNHP
21-
from easy_tpp.model.tf_model.tf_anhn import ANHN as TfANHN
22-
from easy_tpp.model.tf_model.tf_fullynn import FullyNN as TfFullyNN
23-
from easy_tpp.model.tf_model.tf_intensity_free import IntensityFree as TfIntensityFree
24-
except ImportError:
25-
pass
26-
2712
__all__ = ['TorchBaseModel',
2813
'TorchNHP',
2914
'TorchAttNHP',
@@ -32,12 +17,5 @@
3217
'TorchFullyNN',
3318
'TorchIntensityFree',
3419
'TorchODETPP',
35-
'TfBaseModel',
36-
'TfNHP',
37-
'TfAttNHP',
38-
'TfTHP',
39-
'TfSAHP',
40-
'TfANHN',
41-
'TfFullyNN',
42-
'TfIntensityFree',
43-
'TfODETPP']
20+
'TorchRMTPP',
21+
'TorchANHN']

easy_tpp/model/tf_model/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)