@@ -110,6 +110,74 @@ def _run(self, Note, Num, Disease, Mask, Age, Demo):
110
110
'Demo' : torch .from_numpy (Demo .reshape (1 , - 1 )).float ()
111
111
}
112
112
113
class staticDataset(Dataset):
    """Map-style dataset over records loaded from a single JSON file.

    Each record is:
    [[text_enc1, text_enc2, ...], [num_enc1, num_enc2, ...],
     disease, mask, age, gender, race, eth]
    """

    def __init__(self, root_dir, dsName, nClassGender, nClassRace, nClassEthnic, transform=None):
        """Load the JSON dataset found at ``root_dir + dsName``.

        Args:
            root_dir: Path prefix; concatenated with ``dsName`` as-is, so it
                must already end with a path separator.
            dsName: File name of the JSON dataset.
            nClassGender: Length of the one-hot vector for gender.
            nClassRace: Length of the one-hot vector for race.
            nClassEthnic: Length of the one-hot vector for ethnicity.
            transform: Stored on the instance; not applied by this class.
        """
        self.root_dir = root_dir
        # Context manager guarantees the file handle is closed even when
        # JSON parsing fails (the bare open() previously leaked it).
        with open(root_dir + dsName, 'r') as fh:
            self.ds = json.load(fh)
        print('Loaded: ', dsName)
        self.transform = transform
        self.nClass = [nClassGender, nClassRace, nClassEthnic]
        # Fixed length every flattened note sequence is padded/truncated to.
        self.max_len = 3000

    def __len__(self):
        """Return the number of records in the dataset."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return the record at ``idx`` as a dict of arrays/tensors.

        Shape of the outputs:
            - Note: LongTensor of length ``max_len``; the tokens of all
              encounters flattened into one sequence, left-padded with 0
              or truncated to the first ``max_len`` tokens
            - Num: float32 numpy array; mean of the numeric encodings
              over axis 0
            - Disease: float tensor, 1 x len(disease)
            - Mask: float tensor, 1 x len(mask)
            - Age: float tensor reshaped to 1 x -1
            - Demo ([gender, race, eth] one-hots concatenated): float
              tensor, 1 x (nClassGender + nClassRace + nClassEthnic)
        """
        Note, Num, Disease, Mask, Age, gender, race, eth = self.ds[idx]

        # Flatten the per-encounter token lists into a single sequence.
        Note = np.asarray([item for sublist in Note for item in sublist])
        Num = np.asarray(Num, dtype='float32').mean(axis=0)
        Disease = np.asarray(Disease, dtype='int')
        Mask = np.asarray(Mask, dtype='int')
        Age = np.asarray(Age, dtype='float32')

        if len(Note) > self.max_len:
            Note = Note[:self.max_len]
        else:
            # Left-pad with zeros so the real tokens sit at the end.
            Note = np.concatenate([np.zeros(self.max_len - Note.shape[0]), Note])

        Note = torch.from_numpy(Note).long()

        gender2 = self._idx2onehot(gender, self.nClass[0])
        race2 = self._idx2onehot(race, self.nClass[1])
        eth2 = self._idx2onehot(eth, self.nClass[2])

        Demo = np.concatenate([gender2, race2, eth2])

        return {'Note': Note,
                'Num': Num,
                'Disease': torch.from_numpy(Disease.reshape(1, -1)).float(),
                'Mask': torch.from_numpy(Mask.reshape(1, -1)).float(),
                'Age': torch.from_numpy(Age.reshape(1, -1)).float(),
                'Demo': torch.from_numpy(Demo.reshape(1, -1)).float()
                }

    def _idx2onehot(self, value_idx, max_idx):
        """Return a length-``max_idx`` one-hot vector for a 1-based index.

        A ``value_idx`` of 0 or less yields an all-zero vector. A
        ``value_idx`` greater than ``max_idx`` raises ``IndexError``.
        """
        temp = np.zeros(max_idx)
        if value_idx > 0:
            temp[value_idx - 1] = 1
        return temp
177
+
178
+
179
+
180
+
113
181
114
182
class padOrTruncateToTensor (object ):
115
183
"""
@@ -669,6 +737,7 @@ def _train(self, epoch, lsTrainAccuracy):
669
737
Note , Num , Disease , Mask , Age , Demo = Variable (Note ).long (), Variable (Num ).float (), Variable (
670
738
Disease ).float (), Variable (Mask ).float (), Variable (Age ).float (), Variable (Demo ).float ()
671
739
740
+
672
741
self .cnt_iter += 1
673
742
674
743
if self .flg_cuda :
0 commit comments