Commit

Added residual domain

leschultz committed May 1, 2024
1 parent 14b0af7 commit b788bc9
Showing 5 changed files with 48 additions and 61 deletions.
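For context, the commit adds a third domain criterion based on normalized absolute residuals (absres/mad_y) alongside the existing rmse/std_y and cdf_area criteria, and threads it through assessment, ground-truth labeling, classifier refitting, and plotting. The sketch below illustrates only the labeling idea on toy data; mad_y, gt_absres, and the cutoff value are simplified stand-ins for quantities madml computes internally, not its API.

import numpy as np

# Toy data: a point is labeled "in domain" when its absolute residual,
# normalized by the spread (MAD) of y, falls at or below a cutoff.
y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y_pred = np.array([1.1, 1.8, 3.4, 3.9, 7.0])

mad_y = np.mean(np.abs(y_true - np.mean(y_true)))  # stand-in for mad_y
absres = np.abs(y_true - y_pred) / mad_y            # the 'absres/mad_y' quantity
gt_absres = 0.5                                      # stand-in ground-truth cutoff

in_domain = absres <= gt_absres  # labels a classifier on dissimilarity would learn
print(absres.round(3), in_domain)

In the actual change, these labels come from assign_ground_truth in models.py and are then used to refit self.model.domain_absres on the out-of-bag dissimilarity predictions in assess.py.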
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

# Package information
name = 'madml'
-version = '2.6.1' # Need to increment every time to push to PyPI
+version = '2.6.2' # Need to increment every time to push to PyPI
description = 'Application domain of machine learning in materials science.'
url = 'https://github.com/leschultz/'\
'materials_application_domain_machine_learning.git'
21 changes: 15 additions & 6 deletions src/madml/assess.py
@@ -125,9 +125,9 @@ def cv(self, split, save_inner_folds=None):
# Predictions
data['r'] = self.y[test]-data['y_pred']
data['z'] = data['r']/data['y_stdc_pred']
-data['|r|'] = data['r'].abs()
-data['|r|/std_y'] = data['|r|']/data['std_y']
-data['|r|/mad_y'] = data['|r|']/data['mad_y']
+data['absres'] = data['r'].abs()
+data['absres/std_y'] = data['absres']/data['std_y']
+data['absres/mad_y'] = data['absres']/data['mad_y']
data['y_stdc_pred/std_y'] = data['y_stdc_pred']/data['std_y']

# Ground truths
@@ -177,6 +177,7 @@ def test(

# Acquire ground truths
self = ground_truth(self, self.y)
+df['gt_absres'] = self.gt_absres
df['gt_rmse'] = self.gt_rmse
df['gt_area'] = self.gt_area

@@ -198,7 +199,11 @@ def test(
cols = p.columns
cols = [i.split(' (')[0] for i in cols]
p.columns = cols
-d = df[['domain_rmse/std_y', 'domain_cdf_area']]
+d = df[[
+'domain_absres/mad_y',
+'domain_rmse/std_y',
+'domain_cdf_area',
+]]

d = pd.concat([
d.reset_index(drop=True),
@@ -215,12 +220,16 @@ def test(
self.model.fit(self.X, self.y, self.g, n_jobs=self.n_jobs)

# Refit on out-of-bag data for final classification models
+self.model.domain_absres.fit(
+df['d_pred'].values,
+df['domain_absres/mad_y'].values,
+)
self.model.domain_rmse.fit(
-df['d_pred'].values,
+df['d_pred_max'].values,
df['domain_rmse/std_y'].values,
)
self.model.domain_area.fit(
-df['d_pred'].values,
+df['d_pred_max'].values,
df['domain_cdf_area'].values,
)

4 changes: 2 additions & 2 deletions src/madml/calculators.py
@@ -248,7 +248,7 @@ def bin_data(data_cv, bins, by='d_pred'):
binmax = bin_groups['d_pred'].max()
counts = bin_groups['z'].count()
stdc = bin_groups['y_stdc_pred/std_y'].mean()
-rmse = bin_groups['|r|/std_y'].apply(lambda x: (sum(x**2)/len(x))**0.5)
+rmse = bin_groups['absres/std_y'].apply(lambda x: (sum(x**2)/len(x))**0.5)

area = bin_groups.apply(lambda x: cdf(
x['z'],
@@ -259,7 +259,7 @@ def bin_data(data_cv, bins, by='d_pred'):
distmean = distmean.to_frame().add_suffix('_mean')
binmax = binmax.to_frame().add_suffix('_max')
stdc = stdc.to_frame().add_suffix('_mean')
-rmse = rmse.to_frame().rename({'|r|/std_y': 'rmse/std_y'}, axis=1)
+rmse = rmse.to_frame().rename({'absres/std_y': 'rmse/std_y'}, axis=1)
counts = counts.to_frame().rename({'z': 'count'}, axis=1)

# Combine data for each bin
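The renamed absres/std_y column feeds the same per-bin statistic as before: the RMSE of the normalized residuals within each dissimilarity bin. The snippet below is a stand-alone sketch of that groupby/apply pattern using made-up data and a hypothetical bin column, not the package's actual binning.

import pandas as pd

# Toy frame mimicking per-point CV results: a bin id plus normalized residuals.
data_cv = pd.DataFrame({
    'bin': [0, 0, 0, 1, 1, 1],
    'absres/std_y': [0.1, 0.2, 0.3, 0.5, 0.7, 1.1],
})

# Per-bin RMSE of normalized residuals, as in bin_data: sqrt(mean(x**2)).
rmse = (
    data_cv
    .groupby('bin')['absres/std_y']
    .apply(lambda x: (sum(x**2) / len(x)) ** 0.5)
    .to_frame()
    .rename({'absres/std_y': 'rmse/std_y'}, axis=1)
)
print(rmse)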
8 changes: 4 additions & 4 deletions src/madml/models.py
@@ -349,7 +349,7 @@ def assign_ground_truth(data_cv, bin_cv):
bin_cv.loc[row, 'gt_area'] = group[3]

# Make labels
-absres = data_cv['|r|/mad_y'] <= data_cv['gt_absres']
+absres = data_cv['absres/mad_y'] <= data_cv['gt_absres']
rmse = data_cv['rmse/std_y'] <= data_cv['gt_rmse']
area = data_cv['cdf_area'] <= data_cv['gt_area']

@@ -476,9 +476,9 @@ def cv(self, split, gs_model, ds_model, X, y, g=None):
data['y_stdu_pred'] = predict_std(gs_model_cv, X_trans_te)
data['d_pred'] = ds_model_cv.predict(X_trans_te)
data['r'] = y[te]-data['y_pred']
-data['|r|'] = data['r'].abs()
-data['|r|/std_y'] = data['|r|']/data['std_y']
-data['|r|/mad_y'] = data['|r|']/data['mad_y']
+data['absres'] = data['r'].abs()
+data['absres/std_y'] = data['absres']/data['std_y']
+data['absres/mad_y'] = data['absres']/data['mad_y']

return data

74 changes: 26 additions & 48 deletions src/madml/plots.py
@@ -72,33 +72,6 @@ def plot_dump(data, fig, ax, name, save, suffix, legend=True):
json.dump(data, handle)


-def residuals(df, save='.', suffix='d'):
-'''
-A plot of absolute residuals vs. dissimilarity.
-inputs:
-df = Data.
-save = The directory to save plot.
-suffix = Append a suffix to the save name.
-'''
-
-data = {}
-fig, ax = pl.subplots()
-
-x = df['d_pred'].values
-y = df['|r|/mad_y'].values
-
-ax.scatter(x, y, marker='.', color='r')
-
-data['x'] = x.tolist()
-data['y'] = y.tolist()
-
-ax.set_xlabel(r'$d$')
-ax.set_ylabel(r'$|y-\hat{y}|/MAD_{y}$')
-
-plot_dump(data, fig, ax, 'residuals', save, suffix, False)


def confidence(df, save='.', suffix='all'):
'''
A plot of absolute residuals vs. dissimilarity.
@@ -159,7 +132,7 @@ def sub(x, y, ylabel, key, gt, gtlabel, metric, color):
gt_area = float(df['gt_area'].min()) # Should all be same

sorters = {
-r'$|y-\hat{y}|$': df['|r|'].values,
+r'$|y-\hat{y}|$': df['absres'].values,
r'$d$': df['d_pred'].values,
r'$\sigma_{c}$': df['y_stdc_pred'].values,
'Random': np.random.uniform(size=df.shape[0]),
@@ -176,7 +149,7 @@ def sub(x, y, ylabel, key, gt, gtlabel, metric, color):

indx = np.argsort(value)
d = value[indx]
-y = df['|r|/std_y'].values[indx]
+y = df['absres/std_y'].values[indx]
z = df['z'].values[indx]

out = parallel(
@@ -277,7 +250,7 @@ def parity(
y = df.y
y_pred = df.y_pred
y_stdc_pred = df.y_stdc_pred
-r_std_y = df['|r|/std_y']
+r_std_y = df['absres/std_y']
d = df.d_pred

rmse = metrics.mean_squared_error(y, y_pred)**0.5
@@ -468,16 +441,17 @@ def bins(df, d, e, elabel, gt, ylabel, gtlabel, save, suffix):
ax.scatter(
x,
y,
-alpha=0.5,
+alpha=0.4,
**p[dom],
)

-ax.fill_between(
-x,
-y,
-color=p[dom]['color'],
-alpha=0.5,
-)
+if suffix != 'absres':
+ax.fill_between(
+x,
+y,
+color=p[dom]['color'],
+alpha=0.5,
+)

data[dom]['x'].append(x.tolist())
data[dom]['y'].append(y.tolist())
@@ -715,16 +689,17 @@ def __init__(
):

self.save = save
-self.domains = ['domain_rmse/std_y', 'domain_cdf_area']
-self.errors = ['rmse/std_y', 'cdf_area']
-self.gts = ['gt_rmse', 'gt_area']
-self.assessments = ['rmse', 'area']
+self.errors = ['absres/mad_y', 'rmse/std_y', 'cdf_area']
+self.domains = ['domain_'+i for i in self.errors]
+self.assessments = ['absres', 'rmse', 'area']
+self.gts = ['gt_'+i for i in self.assessments]
self.precs = precs # Precisions used

# For plotting purposes on the histogram of E^* vs. D
-cols = self.errors+self.domains
-self.df = df.sort_values(by=['d_pred']+cols)
-self.df_bin = df_bin.sort_values(by=['d_pred_max']+cols)
+df_cols = self.errors+self.domains
+bin_cols = [i for i in df_cols if 'absres' not in i]
+self.df = df.sort_values(by=['d_pred']+df_cols)
+self.df_bin = df_bin.sort_values(by=['d_pred_max']+bin_cols)

self.bins = self.df_bin.shape[0]

@@ -738,6 +713,8 @@ def __init__(
self.df_confusion = df_confusion
self.pred_cols = [i.split(' (')[0] for i in pred_cols]

+self.pred_cols = set(self.pred_cols)

def generate(self):

# Write test data
@@ -755,9 +732,6 @@ def generate(self):
'fit_splitter',
)

-# Residuals
-residuals(self.df, self.save)

# Confidence
confidence(self.df, self.save)
confidence(df, self.save, 'fit_splitter')
@@ -799,6 +773,9 @@ def generate(self):
elif k == 'area':
ename = r'$E^{area}$'
cname = r'$E^{area}_{c}$'
+elif k == 'absres':
+ename = r'$E^{|y-\hat{y}|/MAD_{y}}$'
+cname = r'$E^{|y-\hat{y}|/MAD_{y}}_{c}$'
else:
raise 'Unsupported error metric'

@@ -852,7 +829,8 @@ def generate(self):

# Confusion matrices
for pred in self.pred_cols:
-if i.replace('domain_', '') in pred:

+if j in pred:

# Confusion matrix for all splitters
y = self.df_confusion.loc[:, i].values
