-
Notifications
You must be signed in to change notification settings - Fork 2
/
dataloader.py
30 lines (22 loc) · 1.1 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os

import numpy as np
def _load_npz(path):
    """Return the ``(signal, label)`` arrays stored in the ``.npz`` file at *path*."""
    # np.load on an .npz returns an NpzFile that keeps the zip archive open;
    # the context manager closes it once both arrays have been read into memory.
    with np.load(path) as npz:
        return npz["signal"], npz["label"]


def _load_split(data_dir, split):
    """Load one split ("train"/"test") for S4b and X11b and stack them along axis 0."""
    parts = [
        _load_npz(os.path.join(data_dir, f"{subject}_{split}.npz"))
        for subject in ("S4b", "X11b")
    ]
    signal = np.concatenate([sig for sig, _ in parts], axis=0)
    label = np.concatenate([lab for _, lab in parts], axis=0)
    return signal, label


def _to_model_layout(signal):
    """Insert a singleton axis 1 and swap the last two axes: (N, A, B) -> (N, 1, B, A)."""
    return np.transpose(np.expand_dims(signal, axis=1), (0, 1, 3, 2))


def _fill_nan_with_mean(data):
    """Replace NaN entries of *data* in place with the mean of its non-NaN entries."""
    nan_mask = np.isnan(data)
    if nan_mask.any():
        data[nan_mask] = np.nanmean(data)
    return data


def read_bci_data(data_dir="."):
    """Load and preprocess the S4b and X11b train/test data.

    Reads ``S4b_train.npz``, ``X11b_train.npz``, ``S4b_test.npz`` and
    ``X11b_test.npz`` from *data_dir*. Each file must contain a ``'signal'``
    and a ``'label'`` array. For each split, the S4b arrays come first and the
    X11b arrays second, concatenated along axis 0.

    Preprocessing applied to both splits:
      * labels are shifted down by 1 (presumably the files store 1-based
        class ids, making the result 0-based -- TODO confirm);
      * each 3-D ``signal`` of shape ``(N, A, B)`` becomes ``(N, 1, B, A)``:
        a singleton axis is inserted at position 1 and the last two axes are
        swapped (assumes A/B are time/channels -- TODO confirm);
      * NaN entries are replaced by the mean of the non-NaN entries of the
        same split.

    NOTE(review): test-set NaNs are filled with the *test* split's own mean,
    not the training mean; confirm this is intended.

    Args:
        data_dir: Directory containing the four ``.npz`` files. Defaults to
            the current working directory, as the original hard-coded paths did.

    Returns:
        Tuple ``(train_data, train_label, test_data, test_label)``.
    """
    train_signal, train_label = _load_split(data_dir, "train")
    test_signal, test_label = _load_split(data_dir, "test")

    train_label = train_label - 1
    test_label = test_label - 1

    train_data = _fill_nan_with_mean(_to_model_layout(train_signal))
    test_data = _fill_nan_with_mean(_to_model_layout(test_signal))

    return train_data, train_label, test_data, test_label
if __name__ == "__main__":
    # Smoke test: load from the current directory and report the array shapes.
    train_data, train_label, test_data, test_label = read_bci_data()
    print(train_data.shape, train_label.shape, test_data.shape, test_label.shape)