Skip to content

Commit 90a41b4

Browse files
committed
Attempting to scale bandwidths by feature importance
1 parent e66604d commit 90a41b4

File tree

4 files changed

+136
-7
lines changed

4 files changed

+136
-7
lines changed

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# Package information
44
name = 'madml'
5-
version = '0.7.5' # Need to increment every time to push to PyPI
5+
version = '0.7.6' # Need to increment every time to push to PyPI
66
description = 'Application domain of machine learning in materials science.'
77
url = 'https://github.com/leschultz/'\
88
'materials_application_domain_machine_learning.git'
@@ -28,6 +28,7 @@
2828
'tensorflow',
2929
'udocker',
3030
'scikeras',
31+
'seaborn',
3132
]
3233

3334
long_description = open('README.md').read()

src/madml/models/combine.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ def __init__(
9797
bins=10,
9898
save=False,
9999
gts=1.0,
100-
gtb=0.25
100+
gtb=0.25,
101+
weigh=None,
101102
):
102103

103104
'''
@@ -110,6 +111,7 @@ def __init__(
110111
save = The location to save figures and data.
111112
gts = The ground truth cutoff for residual magnitude test.
112113
gtb = The ground truth cutoff for statistical test.
114+
weigh = Whether to weight distance features.
113115
'''
114116

115117
self.gs_model = gs_model
@@ -120,6 +122,7 @@ def __init__(
120122
self.splits = copy.deepcopy(splits)
121123
self.gts = gts
122124
self.gtb = gtb
125+
self.weigh = weigh
123126

124127
self.dists = []
125128
self.methods = ['']
@@ -230,7 +233,17 @@ def cv(self, split, gs_model, ds_model, X, y, g):
230233
data['y_stdu'] = self.std_pred(gs_model_cv, X_trans_te)
231234

232235
if self.ds_model:
236+
233237
ds_model_cv = copy.deepcopy(ds_model)
238+
239+
mod_attr = gs_model_cv.best_estimator_.named_steps['model']
240+
attr = dir(mod_attr)
241+
242+
condition = (any([i in attr for i in ['feature_importances_']]))
243+
condition = condition and (self.weigh is True)
244+
if condition:
245+
ds_model_cv.weights = mod_attr.feature_importances_
246+
234247
ds_model_cv.fit(X_trans_tr)
235248

236249
data['dist'] = ds_model_cv.predict(X_trans_te)
@@ -307,6 +320,14 @@ def fit(self, X, y, g):
307320
)
308321

309322
# Fit distance model
323+
mod_attr = self.gs_model.best_estimator_.named_steps['model']
324+
attr = dir(mod_attr)
325+
326+
condition = (any([i in attr for i in ['feature_importances_']]))
327+
condition = condition and (self.weigh is True)
328+
if condition:
329+
self.ds_model.weights = mod_attr.feature_importances_
330+
310331
self.ds_model.fit(X_trans)
311332

312333
out = plots.generate_plots(

src/madml/models/space.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,40 @@
66
import numpy as np
77

88

9+
class weighted_model:
10+
11+
def __init__(self, bandwidth, weights, kernel):
12+
self.bandwidths = bandwidth*weights
13+
self.kernel = kernel
14+
15+
def fit(self, X_train):
16+
self.models = []
17+
for b in range(self.bandwidths.shape[0]):
18+
self.model = KernelDensity(
19+
kernel=self.kernel,
20+
bandwidth=self.bandwidths[b],
21+
).fit(X_train[:, b:b+1])
22+
23+
self.models.append(self.model)
24+
25+
def score_samples(self, X):
26+
scores = []
27+
for b in range(self.bandwidths.shape[0]):
28+
score = self.models[b].score_samples(X[:, b:b+1])
29+
scores.append(score)
30+
31+
return np.sum(scores, axis=0)
32+
33+
def return_bandwidths(self):
34+
return self.bandwidths
35+
36+
937
class distance_model:
1038

11-
def __init__(self, dist='kde', *args, **kwargs):
39+
def __init__(self, dist='kde', weights=None, *args, **kwargs):
1240

1341
self.dist = dist
42+
self.weights = weights
1443
self.args = args
1544
self.kwargs = kwargs
1645

@@ -42,16 +71,26 @@ def fit(
4271
self.bandwidth = estimate_bandwidth(X_train)
4372

4473
# If the estimated bandwidth is zero
45-
if self.bandwidth > 0.0:
74+
if (self.weights is None) and (self.bandwidth == 0.0):
4675
self.model = KernelDensity(
4776
kernel=self.kernel,
48-
bandwidth=self.bandwidth,
4977
).fit(X_train)
50-
else:
78+
self.bandwidth = self.model.bandwidth # Update
79+
80+
elif (self.weights is None) and (self.bandwidth > 0.0):
5181
self.model = KernelDensity(
5282
kernel=self.kernel,
83+
bandwidth=self.bandwidth,
5384
).fit(X_train)
54-
self.bandwidth = self.model.bandwidth # Update
85+
else:
86+
87+
self.model = weighted_model(
88+
self.bandwidth,
89+
self.weights,
90+
self.kernel
91+
)
92+
self.model.fit(X_train)
93+
self.bandwidth = self.model.bandwidths
5594

5695
dist = self.model.score_samples(X_train)
5796
m = max(dist)

src/madml/plots.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from functools import reduce
1212
from sklearn import metrics
1313

14+
import seaborn as sns
1415
import pandas as pd
1516
import numpy as np
1617

@@ -323,10 +324,12 @@ def cdf(x, save=None, binsave=None, subsave='', choice='standard_normal'):
323324

324325
cdf_name = 'cdf'
325326
parity_name = 'cdf_parity'
327+
dist_name = 'distribution'
326328
if binsave is not None:
327329
save = os.path.join(save, 'each_bin')
328330
cdf_name = '{}_{}'.format(cdf_name, binsave)
329331
parity_name = '{}_{}'.format(parity_name, binsave)
332+
dist_name = '{}_{}'.format(dist_name, binsave)
330333

331334
os.makedirs(save, exist_ok=True)
332335

@@ -472,6 +475,71 @@ def cdf(x, save=None, binsave=None, subsave='', choice='standard_normal'):
472475
), 'w') as handle:
473476
json.dump(data, handle)
474477

478+
fig, ax = pl.subplots()
479+
480+
sns.histplot(
481+
z,
482+
kde=True,
483+
stat='density',
484+
color='g',
485+
ax=ax,
486+
label='Standard Normal Distribution',
487+
)
488+
489+
sns.histplot(
490+
x,
491+
kde=True,
492+
stat='density',
493+
color='r',
494+
ax=ax,
495+
label='Observed Distribution',
496+
)
497+
498+
ax.set_xlabel('z')
499+
ax.set_ylabel('Fraction')
500+
501+
fig.tight_layout()
502+
503+
fig_legend, ax_legend = pl.subplots()
504+
ax_legend.axis(False)
505+
legend = ax_legend.legend(
506+
*ax.get_legend_handles_labels(),
507+
frameon=False,
508+
loc='center',
509+
bbox_to_anchor=(0.5, 0.5)
510+
)
511+
ax_legend.spines['top'].set_visible(False)
512+
ax_legend.spines['bottom'].set_visible(False)
513+
ax_legend.spines['left'].set_visible(False)
514+
ax_legend.spines['right'].set_visible(False)
515+
516+
fig.savefig(os.path.join(
517+
save,
518+
'{}{}.png'.format(dist_name, subsave),
519+
), bbox_inches='tight')
520+
521+
fig_legend.savefig(os.path.join(
522+
save,
523+
'{}{}_legend.png'.format(
524+
dist_name,
525+
subsave
526+
),
527+
), bbox_inches='tight')
528+
529+
pl.close(fig)
530+
pl.close(fig_legend)
531+
532+
data = {}
533+
data['x'] = list(eval_points)
534+
data['y'] = list(y)
535+
data['y_pred'] = list(y_pred)
536+
data['Area'] = areacdf
537+
with open(os.path.join(
538+
save,
539+
'{}{}.json'.format(cdf_name, subsave),
540+
), 'w') as handle:
541+
json.dump(data, handle)
542+
475543
return y, y_pred, areaparity, areacdf
476544

477545

0 commit comments

Comments
 (0)