Skip to content

Commit 1eecbbf

Browse files
authored
Merge pull request #537 from lss-1138/main
Add SegRNN Implementation
2 parents 903bdab + 366a584 commit 1eecbbf

File tree

22 files changed

+2777
-0
lines changed

22 files changed

+2777
-0
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ The paper references and links are all listed at the bottom of this file.
126126
| Neural Net | iTransformer🧑‍🔧[^24] || | | | | `2024 - ICLR` |
127127
| Neural Net | ModernTCN[^38] || | | | | `2024 - ICLR` |
128128
| Neural Net | ImputeFormer🧑‍🔧[^34] || | | | | `2024 - KDD` |
129+
| Neural Net | SegRNN[^42] || | | | | `2023 - arXiv` |
129130
| Neural Net | SAITS[^1] || | | | | `2023 - ESWA` |
130131
| Neural Net | FreTS🧑‍🔧[^23] || | | | | `2023 - NeurIPS` |
131132
| Neural Net | Koopa🧑‍🔧[^29] || | | | | `2023 - NeurIPS` |
@@ -509,3 +510,6 @@ Time-Series.AI</a>
509510
[^41]: Xu, Z., Zeng, A., & Xu, Q. (2024).
510511
[FITS: Modeling Time Series with 10k parameters](https://openreview.net/forum?id=bWcnvZ3qMb).
511512
*ICLR 2024*.
513+
[^42]: Lin, S., Lin, W., Wu, W., Zhao, F., Mo, R., & Zhang, H. (2023).
514+
[Segrnn: Segment recurrent neural network for long-term time series forecasting](https://github.com/lss-1138/SegRNN)
515+
*arXiv 2023*.

pypots/classification/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
# License: BSD-3-Clause
77

88
from .brits import BRITS
9+
from .csai import CSAI
910
from .grud import GRUD
1011
from .raindrop import Raindrop
1112

1213
__all__ = [
14+
"CSAI",
1315
"BRITS",
1416
"GRUD",
1517
"Raindrop",
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""
2+
The package including the modules of CSAI.
3+
4+
Refer to the paper
5+
`Linglong Qian, Zina Ibrahim, Hugh Logan Ellis, Ao Zhang, Yuezhou Zhang, Tao Wang, Richard Dobson.
6+
Knowledge Enhanced Conditional Imputation for Healthcare Time-series.
7+
In Arxiv, 2024.
8+
<https://arxiv.org/abs/2312.16713>`_
9+
10+
Notes
11+
-----
12+
This implementation is inspired by the official one the official implementation https://github.com/LinglongQian/CSAI.
13+
14+
"""
15+
16+
from .model import CSAI
17+
18+
__all__ = [
19+
"CSAI",
20+
]

pypots/classification/csai/core.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""
2+
3+
"""
4+
5+
# Created by Linglong Qian, Joseph Arul Raj <[email protected], [email protected]>
6+
# License: BSD-3-Clause
7+
8+
import torch
9+
import torch.nn as nn
10+
import torch.nn.functional as F
11+
12+
from ...nn.modules.csai import BackboneBCSAI
13+
14+
# class DiceBCELoss(nn.Module):
15+
# def __init__(self, weight=None, size_average=True):
16+
# super(DiceBCELoss, self).__init__()
17+
# self.bcelogits = nn.BCEWithLogitsLoss()
18+
19+
# def forward(self, y_score, y_out, targets, smooth=1):
20+
21+
# #comment out if your model contains a sigmoid or equivalent activation layer
22+
# # inputs = F.sigmoid(inputs)
23+
24+
# #flatten label and prediction tensors
25+
# BCE = self.bcelogits(y_out, targets)
26+
27+
# y_score = y_score.view(-1)
28+
# targets = targets.view(-1)
29+
# intersection = (y_score * targets).sum()
30+
# dice_loss = 1 - (2.*intersection + smooth)/(y_score.sum() + targets.sum() + smooth)
31+
32+
# Dice_BCE = BCE + dice_loss
33+
34+
# return BCE, Dice_BCE
35+
36+
37+
class _BCSAI(nn.Module):
38+
def __init__(
39+
self,
40+
n_steps: int,
41+
n_features: int,
42+
rnn_hidden_size: int,
43+
imputation_weight: float,
44+
consistency_weight: float,
45+
classification_weight: float,
46+
n_classes: int,
47+
step_channels: int,
48+
dropout: float = 0.5,
49+
intervals=None,
50+
):
51+
super().__init__()
52+
self.n_steps = n_steps
53+
self.n_features = n_features
54+
self.rnn_hidden_size = rnn_hidden_size
55+
self.imputation_weight = imputation_weight
56+
self.consistency_weight = consistency_weight
57+
self.classification_weight = classification_weight
58+
self.n_classes = n_classes
59+
self.step_channels = step_channels
60+
self.intervals = intervals
61+
62+
# create models
63+
self.model = BackboneBCSAI(n_steps, n_features, rnn_hidden_size, step_channels, intervals)
64+
self.f_classifier = nn.Linear(self.rnn_hidden_size, n_classes)
65+
self.b_classifier = nn.Linear(self.rnn_hidden_size, n_classes)
66+
self.imputer = nn.Linear(self.rnn_hidden_size, n_features)
67+
self.dropout = nn.Dropout(dropout)
68+
69+
def forward(self, inputs: dict, training: bool = True) -> dict:
70+
71+
(
72+
imputed_data,
73+
f_reconstruction,
74+
b_reconstruction,
75+
f_hidden_states,
76+
b_hidden_states,
77+
consistency_loss,
78+
reconstruction_loss,
79+
) = self.model(inputs)
80+
81+
results = {
82+
"imputed_data": imputed_data,
83+
}
84+
85+
f_logits = self.f_classifier(self.dropout(f_hidden_states))
86+
b_logits = self.b_classifier(self.dropout(b_hidden_states))
87+
88+
# f_prediction = torch.sigmoid(f_logits)
89+
# b_prediction = torch.sigmoid(b_logits)
90+
91+
f_prediction = torch.softmax(f_logits, dim=1)
92+
b_prediction = torch.softmax(b_logits, dim=1)
93+
classification_pred = (f_prediction + b_prediction) / 2
94+
95+
results = {
96+
"imputed_data": imputed_data,
97+
"classification_pred": classification_pred,
98+
}
99+
100+
# if in training mode, return results with losses
101+
if training:
102+
# criterion = DiceBCELoss().to(imputed_data.device)
103+
results["consistency_loss"] = consistency_loss
104+
results["reconstruction_loss"] = reconstruction_loss
105+
# print(inputs["labels"].unsqueeze(1))
106+
f_classification_loss = F.nll_loss(torch.log(f_prediction), inputs["labels"])
107+
b_classification_loss = F.nll_loss(torch.log(b_prediction), inputs["labels"])
108+
# f_classification_loss, _ = criterion(f_prediction, f_logits, inputs["labels"].unsqueeze(1).float())
109+
# b_classification_loss, _ = criterion(b_prediction, b_logits, inputs["labels"].unsqueeze(1).float())
110+
classification_loss = (f_classification_loss + b_classification_loss)
111+
112+
loss = (
113+
self.consistency_weight * consistency_loss +
114+
self.imputation_weight * reconstruction_loss +
115+
self.classification_weight * classification_loss
116+
)
117+
118+
results["loss"] = loss
119+
results["classification_loss"] = classification_loss
120+
results["f_reconstruction"] = f_reconstruction
121+
results["b_reconstruction"] = b_reconstruction
122+
123+
return results

pypots/classification/csai/data.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""
2+
3+
"""
4+
5+
# Created by Joseph Arul Raj <[email protected]>
6+
# License: BSD-3-Clause
7+
8+
from typing import Union
9+
from ...imputation.csai.data import DatasetForCSAI as DatasetForCSAI_Imputation
10+
11+
12+
13+
class DatasetForCSAI(DatasetForCSAI_Imputation):
14+
def __init__(self,
15+
data: Union[dict, str],
16+
file_type: str = "hdf5",
17+
return_y: bool = True,
18+
removal_percent: float = 0.0,
19+
increase_factor: float = 0.1,
20+
compute_intervals: bool = False,
21+
replacement_probabilities = None,
22+
normalise_mean : list = [],
23+
normalise_std: list = [],
24+
training: bool = True
25+
):
26+
super().__init__(
27+
data=data,
28+
return_X_ori=False,
29+
return_y=return_y,
30+
file_type=file_type,
31+
removal_percent=removal_percent,
32+
increase_factor=increase_factor,
33+
compute_intervals=compute_intervals,
34+
replacement_probabilities=replacement_probabilities,
35+
normalise_mean=normalise_mean,
36+
normalise_std=normalise_std,
37+
training=training
38+
)
39+

0 commit comments

Comments
 (0)