1+ """
2+
3+ """
4+
5+ # Created by Linglong Qian, Joseph Arul Raj <[email protected] , [email protected] > 6+ # License: BSD-3-Clause
7+
8+ import torch
9+ import torch .nn as nn
10+ import torch .nn .functional as F
11+
12+ from ...nn .modules .csai import BackboneBCSAI
13+
14+ # class DiceBCELoss(nn.Module):
15+ # def __init__(self, weight=None, size_average=True):
16+ # super(DiceBCELoss, self).__init__()
17+ # self.bcelogits = nn.BCEWithLogitsLoss()
18+
19+ # def forward(self, y_score, y_out, targets, smooth=1):
20+
21+ # #comment out if your model contains a sigmoid or equivalent activation layer
22+ # # inputs = F.sigmoid(inputs)
23+
24+ # #flatten label and prediction tensors
25+ # BCE = self.bcelogits(y_out, targets)
26+
27+ # y_score = y_score.view(-1)
28+ # targets = targets.view(-1)
29+ # intersection = (y_score * targets).sum()
30+ # dice_loss = 1 - (2.*intersection + smooth)/(y_score.sum() + targets.sum() + smooth)
31+
32+ # Dice_BCE = BCE + dice_loss
33+
34+ # return BCE, Dice_BCE
35+
36+
37+ class _BCSAI (nn .Module ):
38+ def __init__ (
39+ self ,
40+ n_steps : int ,
41+ n_features : int ,
42+ rnn_hidden_size : int ,
43+ imputation_weight : float ,
44+ consistency_weight : float ,
45+ classification_weight : float ,
46+ n_classes : int ,
47+ step_channels : int ,
48+ dropout : float = 0.5 ,
49+ intervals = None ,
50+ ):
51+ super ().__init__ ()
52+ self .n_steps = n_steps
53+ self .n_features = n_features
54+ self .rnn_hidden_size = rnn_hidden_size
55+ self .imputation_weight = imputation_weight
56+ self .consistency_weight = consistency_weight
57+ self .classification_weight = classification_weight
58+ self .n_classes = n_classes
59+ self .step_channels = step_channels
60+ self .intervals = intervals
61+
62+ # create models
63+ self .model = BackboneBCSAI (n_steps , n_features , rnn_hidden_size , step_channels , intervals )
64+ self .f_classifier = nn .Linear (self .rnn_hidden_size , n_classes )
65+ self .b_classifier = nn .Linear (self .rnn_hidden_size , n_classes )
66+ self .imputer = nn .Linear (self .rnn_hidden_size , n_features )
67+ self .dropout = nn .Dropout (dropout )
68+
69+ def forward (self , inputs : dict , training : bool = True ) -> dict :
70+
71+ (
72+ imputed_data ,
73+ f_reconstruction ,
74+ b_reconstruction ,
75+ f_hidden_states ,
76+ b_hidden_states ,
77+ consistency_loss ,
78+ reconstruction_loss ,
79+ ) = self .model (inputs )
80+
81+ results = {
82+ "imputed_data" : imputed_data ,
83+ }
84+
85+ f_logits = self .f_classifier (self .dropout (f_hidden_states ))
86+ b_logits = self .b_classifier (self .dropout (b_hidden_states ))
87+
88+ # f_prediction = torch.sigmoid(f_logits)
89+ # b_prediction = torch.sigmoid(b_logits)
90+
91+ f_prediction = torch .softmax (f_logits , dim = 1 )
92+ b_prediction = torch .softmax (b_logits , dim = 1 )
93+ classification_pred = (f_prediction + b_prediction ) / 2
94+
95+ results = {
96+ "imputed_data" : imputed_data ,
97+ "classification_pred" : classification_pred ,
98+ }
99+
100+ # if in training mode, return results with losses
101+ if training :
102+ # criterion = DiceBCELoss().to(imputed_data.device)
103+ results ["consistency_loss" ] = consistency_loss
104+ results ["reconstruction_loss" ] = reconstruction_loss
105+ # print(inputs["labels"].unsqueeze(1))
106+ f_classification_loss = F .nll_loss (torch .log (f_prediction ), inputs ["labels" ])
107+ b_classification_loss = F .nll_loss (torch .log (b_prediction ), inputs ["labels" ])
108+ # f_classification_loss, _ = criterion(f_prediction, f_logits, inputs["labels"].unsqueeze(1).float())
109+ # b_classification_loss, _ = criterion(b_prediction, b_logits, inputs["labels"].unsqueeze(1).float())
110+ classification_loss = (f_classification_loss + b_classification_loss )
111+
112+ loss = (
113+ self .consistency_weight * consistency_loss +
114+ self .imputation_weight * reconstruction_loss +
115+ self .classification_weight * classification_loss
116+ )
117+
118+ results ["loss" ] = loss
119+ results ["classification_loss" ] = classification_loss
120+ results ["f_reconstruction" ] = f_reconstruction
121+ results ["b_reconstruction" ] = b_reconstruction
122+
123+ return results
0 commit comments