"""
Module to calculate permutation importance
Author: Son Gyo Jung
Email: [email protected]
"""
import os
import pandas as pd
import joblib
import matplotlib.pyplot as plt
from sklearn.inspection import permutation_importance
from sklearn import metrics
from lightgbm.sklearn import LGBMClassifier, LGBMRegressor
from xgboost import XGBClassifier, XGBRegressor
class permutation_importance_analysis():
    """
    Class to examine the permutation importance of exploratory features.

    args:
        (1) path_to_file (type:str) - location of the data file with features
        (2) path_to_save (type:str) - location to save new data files
        (3) path_to_features (type:str) - location of the features to use (e.g. those with multicollinearity reduced)
        (4) problem (type:str) - whether it is a 'classification' or 'regression' problem

    return:
        (1) result of permutation analysis
    """

    def __init__(self, path_to_file, path_to_save, path_to_features, problem, *args, **kwargs):
        self.path_to_save = path_to_save
        self.sample_train = joblib.load(path_to_file)
        self.RFE_features = joblib.load(path_to_features)

        # Last column taken as the target variable or classes
        self.features = self.sample_train.columns.values[:-1]
        self.target = self.sample_train.columns.values[-1]
        self.problem = problem

    def base_model(self, boosting_method):
        """
        Select the baseline model.

        Note:
            For classification, multi-class models are defined as shown below.
            This can be changed into a binary problem by changing the 'objective' to 'binary' for LGBMClassifier,
            or to 'binary:logistic' or 'binary:logitraw' for XGBClassifier (see description in links below):
            https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html
            https://xgboost.readthedocs.io/en/latest/parameter.html
            https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMRegressor.html

        args:
            (1) boosting_method (type:str) - either 'lightGBM' or 'XGBoost'

        return:
            (1) baseline model

        raises:
            ValueError - if 'problem' or 'boosting_method' is not a recognised value
        """
        if self.problem == 'classification':
            if boosting_method == 'lightGBM':
                self.estimator = LGBMClassifier(
                    boosting_type = 'gbdt',
                    objective = 'multiclass',
                    importance_type = 'gain',
                    max_depth = -1
                )
            elif boosting_method == 'XGBoost':
                self.estimator = XGBClassifier(
                    objective = 'multi:softprob',
                    booster = 'gbtree',
                    importance_type = 'total_gain'
                )
            else:
                raise ValueError("boosting_method must be 'lightGBM' or 'XGBoost'")

        elif self.problem == 'regression':
            if boosting_method == 'lightGBM':
                self.estimator = LGBMRegressor(
                    boosting_type = 'gbdt',
                    importance_type = 'gain',
                    max_depth = -1
                )
            elif boosting_method == 'XGBoost':
                # Bug fix: this branch previously instantiated XGBClassifier,
                # which is wrong for a regression objective.
                self.estimator = XGBRegressor(
                    objective = 'reg:squarederror',
                    booster = 'gbtree',
                    importance_type = 'total_gain'
                )
            else:
                raise ValueError("boosting_method must be 'lightGBM' or 'XGBoost'")

        else:
            raise ValueError("problem must be 'classification' or 'regression'")

        return self.estimator

    def calculate(self):
        """
        Train the baseline model and calculate the permutation importance.

        The result is saved to 'permutation_importance_result.pkl' in path_to_save.
        Call base_model() first so that self.estimator is defined.
        """
        # Train baseline model on the RFE-selected features
        self.model = self.estimator.fit(
            self.sample_train[self.RFE_features],
            self.sample_train[self.target].values.ravel()
        )

        # Define metric to use
        if self.problem == 'classification':
            self.scoring = 'f1_weighted'
        elif self.problem == 'regression':
            self.scoring = 'neg_root_mean_squared_error'

        # Number of times to permute a feature
        n_repeats = 10

        # Calculate permutation importance
        # (Series.ravel() is deprecated; use .values.ravel() as in fit() above)
        self.result = permutation_importance(
            self.model,
            self.sample_train[self.RFE_features],
            self.sample_train[self.target].values.ravel(),
            scoring = self.scoring,
            n_repeats = n_repeats
        )

        joblib.dump(self.result, os.path.join(self.path_to_save, r'permutation_importance_result.pkl'))
        print('Permutation data saved as: "permutation_importance_result.pkl"')

    def permutation_plot(self, no_features = 1):
        """
        Generate permutation plot.

        args:
            (1) no_features (type:int) - number of top features to include in the plot

        return:
            (1) figure of permutation importance plot
        """
        # List of feature names in training-column order.
        # Bug fix: previously used self.model.feature_name_, which only exists on
        # LightGBM models; RFE_features is the same list and works for XGBoost too.
        feature_list = list(self.RFE_features)

        # Indices of features based on permutation importance (ascending)
        sorted_index = self.result.importances_mean.argsort()

        # Permutation importance plot: all features (left), top features (right)
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (12, 8))

        ax1.boxplot(
            self.result.importances[sorted_index].T,
            vert = False,
            labels = [feature_list[i] for i in sorted_index]
        )

        ax2.boxplot(
            self.result.importances[sorted_index][-no_features:].T,
            vert = False,
            labels = [feature_list[i] for i in sorted_index][-no_features:]
        )

        fontsize = 15
        ax1.set_xlabel('Reduction in performance metric', fontsize = fontsize)
        ax2.set_xlabel('Reduction in performance metric', fontsize = fontsize)
        ax1.set_ylabel('Feature number in order of importance', fontsize = fontsize)

        fig.tight_layout()
        plt.show()

        fig.savefig(os.path.join(self.path_to_save, r'permutation_plot.png'), dpi = 300, bbox_inches="tight")
        print('Permutation plot saved as: "permutation_plot.png"')