Skip to content

Commit fd678cb

Browse files
author
zz1409
committed
update code
1 parent 17364b5 commit fd678cb

7 files changed

+1270
-8
lines changed

enc_model.py

+69
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,74 @@ def _run(self, Note, Num, Disease, Mask, Age, Demo):
110110
'Demo': torch.from_numpy(Demo.reshape(1, -1)).float()
111111
}
112112

113+
class staticDataset(Dataset):
    """Dataset that collapses each patient's encounters into one static sample.

    Each record in the JSON file is:
        [[text_enc1, text_enc2, ...], [num_enc1, num_enc2, ...],
         disease, mask, age, gender, race, eth]

    The text encounters are flattened into a single token sequence and the
    numeric encounters are averaged, so every sample has a fixed size.
    """

    def __init__(self, root_dir, dsName, nClassGender, nClassRace, nClassEthnic, transform=None, max_len=3000):
        """Load the JSON dataset found at ``root_dir + dsName``.

        Args:
            root_dir: Path prefix; concatenated directly with ``dsName``
                (no separator is inserted).
            dsName: File name of the JSON dataset.
            nClassGender: Width of the gender one-hot vector.
            nClassRace: Width of the race one-hot vector.
            nClassEthnic: Width of the ethnicity one-hot vector.
            transform: Stored on the instance; not applied by this class.
            max_len: Fixed length of the flattened note sequence.
        """
        self.root_dir = root_dir
        # Use a context manager so the file handle is closed deterministically.
        with open(root_dir + dsName, 'r') as fh:
            self.ds = json.load(fh)
        print('Loaded: ', dsName)
        self.transform = transform
        self.nClass = [nClassGender, nClassRace, nClassEthnic]
        self.max_len = max_len

    def __len__(self):
        """Return the number of records in the dataset."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return record ``idx`` as a dict of fixed-size arrays/tensors.

        Shape of the outputs:
            - Note: LongTensor of length ``max_len`` (all encounters
              flattened, truncated at the end or left-padded with zeros)
            - Num: float32 numpy array, mean of the numeric features over
              encounters (axis 0)
            - Disease: float tensor, 1 x len(disease)
            - Mask: float tensor, 1 x len(mask)
            - Age: float tensor, 1 x 1
            - Demo: float tensor, 1 x (nClassGender + nClassRace + nClassEthnic)
        """
        Note, Num, Disease, Mask, Age, gender, race, eth = self.ds[idx]

        # Flatten the per-encounter token lists into one sequence.
        Note = np.asarray([item for sublist in Note for item in sublist])
        Num = np.asarray(Num, dtype='float32').mean(axis=0)
        Disease = np.asarray(Disease, dtype='int')
        Mask = np.asarray(Mask, dtype='int')
        Age = np.asarray(Age, dtype='float32')

        if len(Note) > self.max_len:
            # Keep the first max_len tokens.
            Note = Note[:self.max_len]
        else:
            # Left-pad with zeros so the tokens sit at the end of the sequence.
            Note = np.concatenate([np.zeros(self.max_len - Note.shape[0]), Note])

        Note = torch.from_numpy(Note).long()

        Demo = np.concatenate([
            self._idx2onehot(gender, self.nClass[0]),
            self._idx2onehot(race, self.nClass[1]),
            self._idx2onehot(eth, self.nClass[2]),
        ])

        return {'Note': Note,
                'Num': Num,
                'Disease': torch.from_numpy(Disease.reshape(1, -1)).float(),
                'Mask': torch.from_numpy(Mask.reshape(1, -1)).float(),
                'Age': torch.from_numpy(Age.reshape(1, -1)).float(),
                'Demo': torch.from_numpy(Demo.reshape(1, -1)).float()
                }

    def _idx2onehot(self, value_idx, max_idx):
        """One-hot encode a 1-based index into a vector of length ``max_idx``.

        A ``value_idx`` of 0 (or below) yields the all-zero vector.
        """
        temp = np.zeros(max_idx)
        if value_idx > 0:
            temp[value_idx - 1] = 1
        return temp
177+
178+
179+
180+
113181

114182
class padOrTruncateToTensor(object):
115183
"""
@@ -669,6 +737,7 @@ def _train(self, epoch, lsTrainAccuracy):
669737
Note, Num, Disease, Mask, Age, Demo = Variable(Note).long(), Variable(Num).float(), Variable(
670738
Disease).float(), Variable(Mask).float(), Variable(Age).float(), Variable(Demo).float()
671739

740+
672741
self.cnt_iter += 1
673742

674743
if self.flg_cuda:

0 commit comments

Comments
 (0)