"""
The implementation of DLinear for the partially-observed time-series imputation task.

Refer to the paper "Zeng, A., Chen, M., Zhang, L., & Xu, Q. (2023).
Are transformers effective for time series forecasting? AAAI 2023".

Notes
-----
Partial implementation uses code from https://github.com/thuml/Time-Series-Library

"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

from typing import Union, Optional

import numpy as np
import torch
from torch.utils.data import DataLoader

from .data import DatasetForDLinear
from .modules.core import _DLinear
from ..base import BaseNNImputer
from ...data.base import BaseDataset
from ...data.checking import check_X_ori_in_val_set
from ...optim.adam import Adam
from ...optim.base import Optimizer
from ...utils.logging import logger

class DLinear(BaseNNImputer):
    """The PyTorch implementation of the DLinear model.
    DLinear is originally proposed by Zeng et al. in :cite:`zeng2023dlinear`.

    Parameters
    ----------
    n_steps :
        The number of time steps in the time-series data sample.

    n_features :
        The number of features in the time-series data sample.

    moving_avg_window_size :
        The window size of the moving-average kernel used for series decomposition.

    individual :
        Whether to use a separate linear layer for each feature.
        If False, a single linear layer is shared across all features.

    batch_size :
        The batch size for training and evaluating the model.

    epochs :
        The number of epochs for training the model.

    patience :
        The patience for the early-stopping mechanism. Given a positive integer, the training process will be
        stopped when the model does not perform better after that number of epochs.
        Leaving it as the default None disables early stopping.

    optimizer :
        The optimizer for model training.
        If not given, will use a default Adam optimizer.

    num_workers :
        The number of subprocesses to use for data loading.
        `0` means data loading will be in the main process, i.e. there won't be subprocesses.

    device :
        The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
        If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
        then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
        If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')],
        the model will be trained in parallel on the multiple devices (so far only parallel training on CUDA
        devices is supported). Other devices like Google TPU and Apple Silicon accelerator MPS may be added
        in the future.

    saving_path :
        The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded
        during training into a tensorboard file). Will not save if not given.

    model_saving_strategy :
        The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"].
        No model will be saved when it is set as None.
        The "best" strategy will only automatically save the best model after the training finished.
        The "better" strategy will automatically save the model during training whenever the model performs
        better than in previous epochs.
        The "all" strategy will save every model after each epoch training.

    References
    ----------
    .. [1] `Zeng, Ailing, Muxi Chen, Lei Zhang, and Qiang Xu.
       "Are transformers effective for time series forecasting?".
       In Proceedings of the AAAI Conference on Artificial Intelligence, vol. 37, no. 9, pp. 11121-11128. 2023.
       <https://ojs.aaai.org/index.php/AAAI/article/view/26317/26089>`_

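    Examples
    --------
    A minimal usage sketch. The import path below follows the PyPOTS package layout, and the toy
    data (random values with artificially masked entries) is only for illustration:

    >>> import numpy as np
    >>> from pypots.imputation import DLinear
    >>> # toy data: 100 samples, 24 time steps, 5 features
    >>> X = np.random.randn(100, 24, 5)
    >>> X[X < -1.5] = np.nan  # NaNs mark the missing values
    >>> model = DLinear(n_steps=24, n_features=5, moving_avg_window_size=5, epochs=10)
    >>> model.fit(train_set={"X": X})
    >>> imputation = model.predict(test_set={"X": X})["imputation"]
    >>> imputation.shape  # same shape as the input
    (100, 24, 5)
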
    """

    def __init__(
        self,
        n_steps: int,
        n_features: int,
        moving_avg_window_size: int,
        individual: bool = False,
        batch_size: int = 32,
        epochs: int = 100,
        patience: Optional[int] = None,
        optimizer: Optional[Optimizer] = Adam(),
        num_workers: int = 0,
        device: Optional[Union[str, torch.device, list]] = None,
        saving_path: Optional[str] = None,
        model_saving_strategy: Optional[str] = "best",
    ):
        super().__init__(
            batch_size,
            epochs,
            patience,
            num_workers,
            device,
            saving_path,
            model_saving_strategy,
        )

        self.n_steps = n_steps
        self.n_features = n_features
        # model hyper-parameters
        self.moving_avg_window_size = moving_avg_window_size
        self.individual = individual

        # set up the model
        self.model = _DLinear(
            n_steps,
            n_features,
            moving_avg_window_size,
            individual,
        )
        self._send_model_to_given_device()
        self._print_model_size()

        # set up the optimizer
        self.optimizer = optimizer
        self.optimizer.init_optimizer(self.model.parameters())

    def _assemble_input_for_training(self, data: list) -> dict:
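        """Assemble a mini-batch from the training data loader into the dict the model expects.

        `X_ori` and `indicating_mask` carry the artificially-masked ground truth used to
        compute the imputation loss during training.
        """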
        (
            indices,
            X,
            missing_mask,
            X_ori,
            indicating_mask,
        ) = self._send_data_to_given_device(data)

        inputs = {
            "X": X,
            "missing_mask": missing_mask,
            "X_ori": X_ori,
            "indicating_mask": indicating_mask,
        }

        return inputs

    def _assemble_input_for_validating(self, data: list) -> dict:
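        """Assemble a validation mini-batch; the layout is identical to the training input."""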
        return self._assemble_input_for_training(data)

    def _assemble_input_for_testing(self, data: list) -> dict:
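        """Assemble a test mini-batch; no ground truth is available here, so only `X` and
        `missing_mask` are passed to the model."""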
        indices, X, missing_mask = self._send_data_to_given_device(data)

        inputs = {
            "X": X,
            "missing_mask": missing_mask,
        }

        return inputs

    def fit(
        self,
        train_set: Union[dict, str],
        val_set: Optional[Union[dict, str]] = None,
        file_type: str = "h5py",
    ) -> None:
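        """Train the imputation model on the given data.

        Parameters
        ----------
        train_set : dict or str
            The dataset for model training, should be a dictionary including the key 'X',
            or a path string locating a data file supported by PyPOTS (e.g. h5 file).
            If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps),
            n_features], which is time-series data for training and can contain missing values.

        val_set : dict or str
            The dataset for model validating, should be a dictionary including keys 'X' and 'X_ori',
            or a path string locating a data file. 'X_ori' is the partially-observed ground truth
            used to evaluate the imputation quality during training.

        file_type : str
            The type of the given file if train_set and val_set are path strings.

        """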
        # Step 1: wrap the input data with classes Dataset and DataLoader
        training_set = DatasetForDLinear(
            train_set, return_X_ori=False, return_labels=False, file_type=file_type
        )
        training_loader = DataLoader(
            training_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )
        val_loader = None
        if val_set is not None:
            if not check_X_ori_in_val_set(val_set):
                raise ValueError("val_set must contain 'X_ori' for model validation.")
            val_set = DatasetForDLinear(
                val_set, return_X_ori=True, return_labels=False, file_type=file_type
            )
            val_loader = DataLoader(
                val_set,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.num_workers,
            )

        # Step 2: train the model and freeze it
        self._train_model(training_loader, val_loader)
        self.model.load_state_dict(self.best_model_dict)
        self.model.eval()  # set the model as eval status to freeze it

        # Step 3: save the model if necessary
        self._auto_save_model_if_necessary(confirm_saving=True)

    def predict(
        self,
        test_set: Union[dict, str],
        file_type: str = "h5py",
    ) -> dict:
        """Make predictions for the input data with the trained model.

        Parameters
        ----------
        test_set : dict or str
            The dataset for model testing, should be a dictionary including the key 'X',
            or a path string locating a data file supported by PyPOTS (e.g. h5 file).
            If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps),
            n_features], which is time-series data for testing and can contain missing values.
            If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
            key-value pairs like a dict, and it has to include the key 'X'.

        file_type : str
            The type of the given file if test_set is a path string.

        Returns
        -------
        result_dict : dict
            The dictionary containing the imputation results, with the key "imputation" mapping to the
            imputed data of shape [n_samples, sequence length (time steps), n_features].

        """
        # Step 1: wrap the input data with classes Dataset and DataLoader
        self.model.eval()  # set the model as eval status to freeze it
        test_set = BaseDataset(
            test_set, return_X_ori=False, return_labels=False, file_type=file_type
        )
        test_loader = DataLoader(
            test_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
        imputation_collector = []

        # Step 2: process the data with the model
        with torch.no_grad():
            for idx, data in enumerate(test_loader):
                inputs = self._assemble_input_for_testing(data)
                results = self.model.forward(inputs, training=False)
                imputation_collector.append(results["imputed_data"])

        # Step 3: output collection and return
        imputation = torch.cat(imputation_collector).cpu().detach().numpy()
        result_dict = {
            "imputation": imputation,
        }
        return result_dict

    def impute(
        self,
        X: Union[dict, str],
        file_type: str = "h5py",
    ) -> np.ndarray:
        """Impute missing values in the given data with the trained model.

        Warnings
        --------
        The method impute is deprecated. Please use `predict()` instead.

        Parameters
        ----------
        X :
            The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps),
            n_features], or a path string locating a data file, e.g. h5 file.

        file_type :
            The type of the given file if X is a path string.

        Returns
        -------
        array-like, shape [n_samples, sequence length (time steps), n_features],
            Imputed data.
        """
        logger.warning(
            "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead."
        )

        results_dict = self.predict(X, file_type=file_type)
        return results_dict["imputation"]