-
Notifications
You must be signed in to change notification settings - Fork 0
/
feature_base.py
37 lines (29 loc) · 1.3 KB
/
feature_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os
from abc import ABCMeta, abstractmethod

from scipy.sparse import issparse
from scipy.sparse import isspmatrix_csr, isspmatrix_csc
from scipy.sparse import save_npz, load_npz


class BaseFeature(metaclass=ABCMeta):
    """Abstract base for features that produce a (train, test) matrix pair.

    Subclasses set ``self.name`` (used to derive the cache file names) and
    implement :meth:`build`. :meth:`load` returns the matrices from the
    ``.npz`` cache in ``self._save_path`` when both cache files exist, and
    otherwise calls :meth:`build` and caches whatever it can.
    """

    # Sparse formats that scipy.sparse.save_npz is able to serialize.
    _NPZ_FORMATS = frozenset({'csr', 'csc', 'bsr', 'dia', 'coo'})

    def __init__(self):
        # Identifier embedded in the cache file names; subclasses must set it.
        self.name = None
        # Human-readable label printed by load().
        self.description = 'Base class of features.'
        # Free-form settings for subclasses; not read by this base class.
        self.config = {}
        # Cache directory, relative to the current working directory.
        self._save_path = './features'
        # exist_ok avoids the exists-then-create race and also creates parents.
        os.makedirs(self._save_path, exist_ok=True)

    @abstractmethod
    def build(self):
        """Compute and return the ``(Xtrain, Xtest)`` pair for this feature."""
        raise NotImplementedError("Must override function 'build'")

    def load(self):
        """Return ``(Xtrain, Xtest)``, using the on-disk cache when available.

        If both ``Xtrain_<name>.npz`` and ``Xtest_<name>.npz`` exist in
        ``self._save_path``, they are read with ``load_npz``. Otherwise
        :meth:`build` is called, and each result that is a SciPy sparse
        matrix/array in an npz-serializable format is saved. Other results
        (e.g. dense arrays) are returned but not cached, so :meth:`build`
        runs again on the next call.

        Raises:
            TypeError: if ``self.name`` is not a string (e.g. still ``None``).
        """
        if not isinstance(self.name, str):
            raise TypeError(
                '{}.name must be a str to derive cache file names, got {!r}'.format(
                    type(self).__name__, self.name))
        xtrain_file = os.path.join(self._save_path, 'Xtrain_' + self.name + '.npz')
        xtest_file = os.path.join(self._save_path, 'Xtest_' + self.name + '.npz')

        if os.path.isfile(xtrain_file) and os.path.isfile(xtest_file):
            xtrain = load_npz(xtrain_file)
            xtest = load_npz(xtest_file)
        else:
            xtrain, xtest = self.build()
            # _save_path may have been changed after __init__ created the default dir.
            os.makedirs(self._save_path, exist_ok=True)
            self._save_if_sparse(xtrain_file, xtrain)
            self._save_if_sparse(xtest_file, xtest)

        print('Feature: [{}], Train shape: [{}], Test shape: [{}]'.format(self.description, xtrain.shape, xtest.shape))
        return xtrain, xtest

    @classmethod
    def _save_if_sparse(cls, path, matrix):
        """Write *matrix* to *path* with save_npz if its sparse format is supported."""
        # issparse covers both spmatrix and the newer sparse-array classes.
        if issparse(matrix) and matrix.format in cls._NPZ_FORMATS:
            save_npz(path, matrix)