Skip to content

Commit

Permalink
The ground truth is now defined by each models ground truth from trai…
Browse files Browse the repository at this point in the history
…ning data
  • Loading branch information
leschultz committed Jan 31, 2024
1 parent 6456aa0 commit 78e9287
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 59 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Package information
name = 'madml'
version = '2.1.7' # Need to increment every time to push to PyPI
version = '2.1.8' # Need to increment every time to push to PyPI
description = 'Application domain of machine learning in materials science.'
url = 'https://github.com/leschultz/'\
'materials_application_domain_machine_learning.git'
Expand Down
19 changes: 4 additions & 15 deletions src/madml/assess.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ def cv(self, split, save_inner_folds=None):
data['r/std_y'] = data['r']/data['std_y']
data['y_stdc_pred/std_y'] = data['y_stdc_pred']/data['std_y']

# Ground truths
data['gt_rmse'] = model.gt_rmse
data['gt_area'] = model.gt_area

return data

def test(
Expand Down Expand Up @@ -135,23 +139,12 @@ def test(

# Full fit
self.model.fit(self.X, self.y, self.g)
self.gt_rmse = self.model.gt_rmse
self.gt_area = self.model.gt_area

pred = self.model.combine_domains_preds(df['d_pred'])
df.drop(pred.columns, axis=1, inplace=True)
df = pd.concat([
df.reset_index(drop=True),
pred.reset_index(drop=True)
], axis=1)

# Ground truths
df, df_bin = bin_data(df, self.model.bins)
df, df_bin = assign_ground_truth(
df,
df_bin,
self.gt_rmse,
self.gt_area,
)

if save_outer_folds is not None:
Expand Down Expand Up @@ -201,8 +194,6 @@ def test(
plot = plotter(
df,
df_bin,
self.gt_rmse,
self.gt_area,
self.model.precs,
ass_save,
)
Expand All @@ -212,8 +203,6 @@ def test(
plot = plotter(
self.model.data_cv,
self.model.bin_cv,
self.model.gt_rmse,
self.model.gt_area,
self.model.precs,
model_ass,
)
Expand Down
45 changes: 22 additions & 23 deletions src/madml/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,37 +266,36 @@ def predict_std(model, X):
return std


def assign_ground_truth(data_cv, bin_cv, gt_rmse, gt_area):
def assign_ground_truth(data_cv, bin_cv):

data_cv = copy.deepcopy(data_cv)
bin_cv = copy.deepcopy(bin_cv)

rmse = bin_cv['rmse/std_y'] <= gt_rmse
area = bin_cv['cdf_area'] <= gt_area
data_cv = data_cv.merge(bin_cv, on=['bin'])

bin_cv['domain_rmse/sigma_y'] = np.where(rmse, 'ID', 'OD')
bin_cv['domain_cdf_area'] = np.where(area, 'ID', 'OD')
# Innitiate arrays
cols = ['gt_rmse', 'gt_area']
for c in cols:
bin_cv[c] = None

cols = [
'domain_rmse/sigma_y',
'domain_cdf_area',
'rmse/std_y',
'cdf_area',
]
# Propagate ground truths
for group, value in data_cv.groupby(['bin', *cols]):
row = bin_cv['bin'] == group[0]
bin_cv.loc[row, 'gt_rmse'] = group[1]
bin_cv.loc[row, 'gt_area'] = group[2]

# Allocate data
for col in cols:
data_cv[col] = None
# Make labels
rmse = data_cv['rmse/std_y'] <= data_cv['gt_rmse']
area = data_cv['cdf_area'] <= data_cv['gt_area']

# Assign bin data to individual points
for i in bin_cv.bin:
data_cv['domain_rmse/sigma_y'] = np.where(rmse, 'ID', 'OD')
data_cv['domain_cdf_area'] = np.where(area, 'ID', 'OD')

# Ground labels based on rmse
row = data_cv['bin'] == i
gt = bin_cv.loc[bin_cv['bin'] == i][cols]
rmse = bin_cv['rmse/std_y'] <= bin_cv['gt_rmse']
area = bin_cv['cdf_area'] <= bin_cv['gt_area']

for col in cols:
data_cv.loc[row, col] = gt[col].values[0]
bin_cv['domain_rmse/sigma_y'] = np.where(rmse, 'ID', 'OD')
bin_cv['domain_cdf_area'] = np.where(area, 'ID', 'OD')

return data_cv, bin_cv

Expand Down Expand Up @@ -472,13 +471,13 @@ def fit(self, X, y, g=None, d_input=None):

# Acquire ground truths
self = ground_truth(self, y)
data_cv['gt_rmse'] = self.gt_rmse
data_cv['gt_area'] = self.gt_area

# Classify ground truth labels
data_cv, bin_cv = assign_ground_truth(
data_cv,
bin_cv,
self.gt_rmse,
self.gt_area,
)

# Fit domain classifiers
Expand Down
28 changes: 8 additions & 20 deletions src/madml/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,21 +251,20 @@ def cdf(df, gt, save, suffix):
plot_dump(data, fig, ax, 'cdf', save, suffix)


def bins(df, d, e, elabel, gt, ylabel, save, suffix):
def bins(df, d, e, elabel, ylabel, save, suffix):
'''
Plot statistical errors with respect to dissimilarity.
inputs:
d = The dissimilarity.
e = The error statistic.
elabel = The domain labels.
gt = The domain ground truth.
ylabel = The y-axis label.
save = The directory to save plot.
suffix = Append a suffix to the save name.
'''

data = {'gt': gt}
data = {}
fig, ax = pl.subplots()
for group, values in df.groupby([elabel, 'bin']):

Expand All @@ -291,12 +290,6 @@ def bins(df, d, e, elabel, gt, ylabel, save, suffix):
data[dom]['x'] = x.tolist()
data[dom]['y'] = y.tolist()

ax.axhline(
gt,
color='g',
label='GT = {:.2f}'.format(gt),
)

ax.set_ylabel(ylabel)
ax.set_xlabel('D')

Expand Down Expand Up @@ -495,8 +488,6 @@ def __init__(
self,
df,
df_bin,
gt_rmse,
gt_area,
precs,
save,
):
Expand All @@ -505,10 +496,9 @@ def __init__(
self.domains = ['domain_rmse/sigma_y', 'domain_cdf_area']
self.errors = ['rmse/std_y', 'cdf_area']
self.assessments = ['rmse', 'area']
self.gts = [gt_rmse, gt_area] # Ground truths
self.precs = precs # Precisions used

# For plotting purposes
# For plotting purposes on the histogram of E^* vs. D
cols = self.errors+self.domains
self.df = df.sort_values(by=['d_pred']+cols)
self.df_bin = df_bin.sort_values(by=['d_pred_max']+cols)
Expand Down Expand Up @@ -548,12 +538,11 @@ def generate(self):
area_vs_rmse(self.df_bin, self.save)

# Loop over domains
for i, j, k, f, in zip(
self.domains,
self.errors,
self.assessments,
self.gts,
):
for i, j, k, in zip(
self.domains,
self.errors,
self.assessments,
):

# Separate domains and classes
for group, df in self.df.groupby(i):
Expand Down Expand Up @@ -588,7 +577,6 @@ def generate(self):
'd_pred',
j,
i,
f,
r'$E^{{{}}}$'.format(k),
self.save,
k,
Expand Down

0 comments on commit 78e9287

Please sign in to comment.