Skip to content

Commit

Permalink
Fixed json bug
Browse files Browse the repository at this point in the history
  • Loading branch information
leschultz committed May 2, 2024
1 parent b788bc9 commit dee8da7
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Package information
name = 'madml'
version = '2.6.2' # Need to increment every time to push to PyPI
version = '2.6.3' # Need to increment every time to push to PyPI
description = 'Application domain of machine learning in materials science.'
url = 'https://github.com/leschultz/'\
'materials_application_domain_machine_learning.git'
Expand Down
12 changes: 6 additions & 6 deletions src/madml/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def predict(self, d, d_input=None):

key = 'Domain Prediction from {} (p={},r={})'.format(key, p, r)
cut = value['Threshold']
do_pred[key] = np.where(d <= cut, 'ID', 'OD')
do_pred[key] = np.where(d < cut, 'ID', 'OD')

if d_input is not None:
do_pred['d_input'] = np.where(d <= d_input, 'ID', 'OD')
Expand Down Expand Up @@ -349,16 +349,16 @@ def assign_ground_truth(data_cv, bin_cv):
bin_cv.loc[row, 'gt_area'] = group[3]

# Make labels
absres = data_cv['absres/mad_y'] <= data_cv['gt_absres']
rmse = data_cv['rmse/std_y'] <= data_cv['gt_rmse']
area = data_cv['cdf_area'] <= data_cv['gt_area']
absres = data_cv['absres/mad_y'] < data_cv['gt_absres']
rmse = data_cv['rmse/std_y'] < data_cv['gt_rmse']
area = data_cv['cdf_area'] < data_cv['gt_area']

data_cv['domain_absres/mad_y'] = np.where(absres, 'ID', 'OD')
data_cv['domain_rmse/std_y'] = np.where(rmse, 'ID', 'OD')
data_cv['domain_cdf_area'] = np.where(area, 'ID', 'OD')

rmse = bin_cv['rmse/std_y'] <= bin_cv['gt_rmse']
area = bin_cv['cdf_area'] <= bin_cv['gt_area']
rmse = bin_cv['rmse/std_y'] < bin_cv['gt_rmse']
area = bin_cv['cdf_area'] < bin_cv['gt_area']

bin_cv['domain_rmse/std_y'] = np.where(rmse, 'ID', 'OD')
bin_cv['domain_cdf_area'] = np.where(area, 'ID', 'OD')
Expand Down
7 changes: 6 additions & 1 deletion src/madml/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,12 @@ def sub(x, y, ylabel, key, gt, gtlabel, metric, color):
ax_sub.set_xlabel(f'Included {key}')
ax_sub.set_ylabel(ylabel)

dat = {'x': x, 'y': y, 'gt': gt}
dat = {
'x': list(map(float, x)),
'y': list(map(float, y)),
'gt': float(gt),
}

plot_dump(
dat,
fig_sub,
Expand Down

0 comments on commit dee8da7

Please sign in to comment.