
Commit 6d41f82

Merge pull request #327 from WenjieDu/(feat)add_dlinear
Add DLinear as an imputation model
2 parents 4425c5b + a812c42 commit 6d41f82

File tree: 10 files changed (+575 −18 lines)

pypots/imputation/__init__.py (+2)

```diff
@@ -14,6 +14,7 @@
 from .transformer import Transformer
 from .timesnet import TimesNet
 from .autoformer import Autoformer
+from .dlinear import DLinear
 from .patchtst import PatchTST
 from .usgan import USGAN
@@ -28,6 +29,7 @@
     "Transformer",
     "TimesNet",
     "PatchTST",
+    "DLinear",
     "Autoformer",
     "BRITS",
     "MRNN",
```

pypots/imputation/autoformer/model.py (+5 −6)

```diff
@@ -101,13 +101,12 @@ class Autoformer(BaseNNImputer):
         better than in previous epochs.
         The "all" strategy will save every model after each epoch training.
 
-    Attributes
+    References
     ----------
-    model : :class:`torch.nn.Module`
-        The underlying Transformer model.
-
-    optimizer : :class:`pypots.optim.Optimizer`
-        The optimizer for model training.
+    .. [1] `Wu, Haixu, Jiehui Xu, Jianmin Wang, and Mingsheng Long.
+       "Autoformer: Decomposition transformers with auto-correlation for long-term series forecasting".
+       Advances in neural information processing systems 34 (2021): 22419-22430.
+       <https://proceedings.neurips.cc/paper/2021/file/bcc0d400288793e8bdcd7c19a8ac0c2b-Paper.pdf>`_
 
     """
```

pypots/imputation/dlinear/__init__.py (new file, +17)

```python
"""
The package of the partially-observed time-series imputation model DLinear.

Refer to the paper "Zeng, A., Chen, M., Zhang, L., & Xu, Q. (2023).
Are transformers effective for time series forecasting? AAAI 2023.".

"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause


from .model import DLinear

__all__ = [
    "DLinear",
]
```
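
The network itself, `_DLinear` in `modules/core.py`, is part of this commit but its diff is not reproduced in this extract. For orientation, here is a minimal sketch of the decomposition-plus-linear idea from the DLinear paper: split the series into a trend component (a moving average over time) and a seasonal remainder, and map each with its own linear layer over the time dimension. `TinyDLinear` is an illustrative re-implementation, not the PyPOTS `_DLinear` module; only `n_steps` and `moving_avg_window_size` mirror the constructor arguments used in model.py below.

```python
import torch
import torch.nn as nn


class TinyDLinear(nn.Module):
    """Illustrative only: decompose a series into trend + seasonal parts,
    then model each part with a single linear layer over the time axis."""

    def __init__(self, n_steps: int, moving_avg_window_size: int):
        super().__init__()
        # a moving average over time extracts the slowly-varying trend;
        # assumes an odd window size so the output keeps n_steps length
        self.moving_avg = nn.AvgPool1d(
            kernel_size=moving_avg_window_size,
            stride=1,
            padding=moving_avg_window_size // 2,
            count_include_pad=False,
        )
        self.linear_trend = nn.Linear(n_steps, n_steps)
        self.linear_seasonal = nn.Linear(n_steps, n_steps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, n_steps, n_features) -> (batch, n_features, n_steps)
        x = x.permute(0, 2, 1)
        trend = self.moving_avg(x)   # low-frequency component
        seasonal = x - trend         # what the moving average removed
        out = self.linear_trend(trend) + self.linear_seasonal(seasonal)
        return out.permute(0, 2, 1)  # back to (batch, n_steps, n_features)
```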

pypots/imputation/dlinear/data.py (new file, +24)

```python
"""
Dataset class for DLinear.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForDLinear(DatasetForSAITS):
    """DLinear uses the same data strategy as SAITS: it needs MIT (masked imputation task) for training."""

    def __init__(
        self,
        data: Union[dict, str],
        return_X_ori: bool,
        return_labels: bool,
        file_type: str = "h5py",
        rate: float = 0.2,
    ):
        super().__init__(data, return_X_ori, return_labels, file_type, rate)
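```

For context, the MIT referred to in the docstring is SAITS's masked imputation task: during training, a fraction (`rate`, 0.2 by default) of the observed values is artificially masked out so the loss can be computed against known ground truth. A rough sketch of that masking idea, matching the `X`/`X_ori`/`missing_mask`/`indicating_mask` keys assembled in model.py below (the helper itself is illustrative, not the PyPOTS implementation):

```python
import numpy as np


def mit_masking(X_ori: np.ndarray, rate: float = 0.2):
    """Illustrative only: hide `rate` of the observed values so the training
    loss can be computed on entries whose ground truth is known."""
    observed = ~np.isnan(X_ori)
    # randomly pick a subset of the observed entries to hold out
    holdout = observed & (np.random.rand(*X_ori.shape) < rate)
    X = X_ori.copy()
    X[holdout] = np.nan  # the model never sees these values
    missing_mask = (~np.isnan(X)).astype("float32")  # 1 = still observed
    indicating_mask = holdout.astype("float32")      # 1 = held out for the loss
    return X, missing_mask, indicating_mask
```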

pypots/imputation/dlinear/model.py (new file, +296)

```python
"""
The implementation of DLinear for the partially-observed time-series imputation task.

Refer to the paper "Zeng, A., Chen, M., Zhang, L., & Xu, Q. (2023).
Are transformers effective for time series forecasting? AAAI 2023".

Notes
-----
Partial implementation uses code from https://github.com/thuml/Time-Series-Library

"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

from typing import Union, Optional

import numpy as np
import torch
from torch.utils.data import DataLoader

from .data import DatasetForDLinear
from .modules.core import _DLinear
from ..base import BaseNNImputer
from ...data.base import BaseDataset
from ...data.checking import check_X_ori_in_val_set
from ...optim.adam import Adam
from ...optim.base import Optimizer
from ...utils.logging import logger


class DLinear(BaseNNImputer):
    """The PyTorch implementation of the DLinear model.
    DLinear is originally proposed by Zeng et al. in :cite:`zeng2023dlinear`.

    Parameters
    ----------
    n_steps :
        The number of time steps in the time-series data sample.

    n_features :
        The number of features in the time-series data sample.

    moving_avg_window_size :
        The window size of the moving average used to extract the trend component.

    individual :
        Whether to use a separate linear layer for each feature rather than sharing one across all features.

    batch_size :
        The batch size for training and evaluating the model.

    epochs :
        The number of epochs for training the model.

    patience :
        The patience for the early-stopping mechanism. Given a positive integer, the training process will be
        stopped when the model does not perform better after that number of epochs.
        Leaving it default as None will disable the early-stopping.

    optimizer :
        The optimizer for model training.
        If not given, will use a default Adam optimizer.

    num_workers :
        The number of subprocesses to use for data loading.
        `0` means data loading will be in the main process, i.e. there won't be subprocesses.

    device :
        The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
        If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
        then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
        If given a list of devices, e.g. ['cuda:0', 'cuda:1'] or [torch.device('cuda:0'), torch.device('cuda:1')],
        the model will be trained in parallel on the multiple devices (so far only parallel training on CUDA
        devices is supported). Other devices like Google TPU and Apple Silicon accelerator MPS may be added
        in the future.

    saving_path :
        The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during
        training into a tensorboard file). Will not save if not given.

    model_saving_strategy :
        The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"].
        No model will be saved when it is set as None.
        The "best" strategy will only automatically save the best model after the training finished.
        The "better" strategy will automatically save the model during training whenever the model performs
        better than in previous epochs.
        The "all" strategy will save every model after each epoch training.

    References
    ----------
    .. [1] `Zeng, Ailing, Muxi Chen, Lei Zhang, and Qiang Xu.
       "Are transformers effective for time series forecasting?".
       In Proceedings of the AAAI conference on artificial intelligence, vol. 37, no. 9, pp. 11121-11128. 2023.
       <https://ojs.aaai.org/index.php/AAAI/article/view/26317/26089>`_

    """

    def __init__(
        self,
        n_steps: int,
        n_features: int,
        moving_avg_window_size: int,
        individual: bool = False,
        batch_size: int = 32,
        epochs: int = 100,
        patience: Optional[int] = None,
        optimizer: Optional[Optimizer] = Adam(),
        num_workers: int = 0,
        device: Optional[Union[str, torch.device, list]] = None,
        saving_path: Optional[str] = None,
        model_saving_strategy: Optional[str] = "best",
    ):
        super().__init__(
            batch_size,
            epochs,
            patience,
            num_workers,
            device,
            saving_path,
            model_saving_strategy,
        )

        self.n_steps = n_steps
        self.n_features = n_features
        # model hyper-parameters
        self.moving_avg_window_size = moving_avg_window_size
        self.individual = individual

        # set up the model
        self.model = _DLinear(
            n_steps,
            n_features,
            moving_avg_window_size,
            individual,
        )
        self._send_model_to_given_device()
        self._print_model_size()

        # set up the optimizer
        self.optimizer = optimizer
        self.optimizer.init_optimizer(self.model.parameters())

    def _assemble_input_for_training(self, data: list) -> dict:
        (
            indices,
            X,
            missing_mask,
            X_ori,
            indicating_mask,
        ) = self._send_data_to_given_device(data)

        inputs = {
            "X": X,
            "missing_mask": missing_mask,
            "X_ori": X_ori,
            "indicating_mask": indicating_mask,
        }

        return inputs

    def _assemble_input_for_validating(self, data: list) -> dict:
        return self._assemble_input_for_training(data)

    def _assemble_input_for_testing(self, data: list) -> dict:
        indices, X, missing_mask = self._send_data_to_given_device(data)

        inputs = {
            "X": X,
            "missing_mask": missing_mask,
        }

        return inputs

    def fit(
        self,
        train_set: Union[dict, str],
        val_set: Optional[Union[dict, str]] = None,
        file_type: str = "h5py",
    ) -> None:
        # Step 1: wrap the input data with classes Dataset and DataLoader
        training_set = DatasetForDLinear(
            train_set, return_X_ori=False, return_labels=False, file_type=file_type
        )
        training_loader = DataLoader(
            training_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )
        val_loader = None
        if val_set is not None:
            if not check_X_ori_in_val_set(val_set):
                raise ValueError("val_set must contain 'X_ori' for model validation.")
            val_set = DatasetForDLinear(
                val_set, return_X_ori=True, return_labels=False, file_type=file_type
            )
            val_loader = DataLoader(
                val_set,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.num_workers,
            )

        # Step 2: train the model and freeze it
        self._train_model(training_loader, val_loader)
        self.model.load_state_dict(self.best_model_dict)
        self.model.eval()  # set the model as eval status to freeze it.

        # Step 3: save the model if necessary
        self._auto_save_model_if_necessary(confirm_saving=True)

    def predict(
        self,
        test_set: Union[dict, str],
        file_type: str = "h5py",
    ) -> dict:
        """Make predictions for the input data with the trained model.

        Parameters
        ----------
        test_set : dict or str
            The dataset for model testing, should be a dictionary including the key 'X',
            or a path string locating a data file supported by PyPOTS (e.g. h5 file).
            If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
            which is time-series data for testing and can contain missing values.
            If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
            key-value pairs like a dict, and it has to include the key 'X'.

        file_type : str
            The type of the given file if test_set is a path string.

        Returns
        -------
        result_dict : dict
            The dictionary containing the imputation results and latent variables if necessary.

        """
        # Step 1: wrap the input data with classes Dataset and DataLoader
        self.model.eval()  # set the model as eval status to freeze it.
        test_set = BaseDataset(
            test_set, return_X_ori=False, return_labels=False, file_type=file_type
        )
        test_loader = DataLoader(
            test_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
        imputation_collector = []

        # Step 2: process the data with the model
        with torch.no_grad():
            for idx, data in enumerate(test_loader):
                inputs = self._assemble_input_for_testing(data)
                results = self.model.forward(inputs, training=False)
                imputation_collector.append(results["imputed_data"])

        # Step 3: output collection and return
        imputation = torch.cat(imputation_collector).cpu().detach().numpy()
        result_dict = {
            "imputation": imputation,
        }
        return result_dict

    def impute(
        self,
        X: Union[dict, str],
        file_type="h5py",
    ) -> np.ndarray:
        """Impute missing values in the given data with the trained model.

        Warnings
        --------
        The method impute is deprecated. Please use `predict()` instead.

        Parameters
        ----------
        X :
            The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps),
            n_features], or a path string locating a data file, e.g. h5 file.

        file_type :
            The type of the given file if X is a path string.

        Returns
        -------
        array-like, shape [n_samples, sequence length (time steps), n_features],
            Imputed data.
        """
        logger.warning(
            "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead."
        )

        results_dict = self.predict(X, file_type=file_type)
        return results_dict["imputation"]
```
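
Taken together, the public API added above can be exercised end to end. A minimal usage sketch (the toy dataset and hyper-parameter values are illustrative assumptions; the `DLinear` constructor, `fit()`, and `predict()` signatures come straight from the file above):

```python
import numpy as np
from pypots.imputation import DLinear

# toy data: 100 samples, 48 time steps, 7 features, ~10% of values missing
X = np.random.randn(100, 48, 7)
X[np.random.rand(*X.shape) < 0.1] = np.nan

model = DLinear(
    n_steps=48,
    n_features=7,
    moving_avg_window_size=5,  # kernel of the trend-extracting moving average
    individual=False,          # share the linear layers across features
    epochs=10,
)
model.fit(train_set={"X": X})
results = model.predict(test_set={"X": X})
imputed = results["imputation"]  # (100, 48, 7) with missing entries filled
```
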
pypots/imputation/dlinear/modules/__init__.py (new file, +6)

```python
"""

"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause
```
