Skip to content

Commit dee8da7

Browse files
committed
Fixed json bug
1 parent b788bc9 commit dee8da7

File tree

3 files changed

+13
-8
lines changed

3 files changed

+13
-8
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# Package information
44
name = 'madml'
5-
version = '2.6.2' # Need to increment every time to push to PyPI
5+
version = '2.6.3' # Need to increment every time to push to PyPI
66
description = 'Application domain of machine learning in materials science.'
77
url = 'https://github.com/leschultz/'\
88
'materials_application_domain_machine_learning.git'

src/madml/models.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def predict(self, d, d_input=None):
276276

277277
key = 'Domain Prediction from {} (p={},r={})'.format(key, p, r)
278278
cut = value['Threshold']
279-
do_pred[key] = np.where(d <= cut, 'ID', 'OD')
279+
do_pred[key] = np.where(d < cut, 'ID', 'OD')
280280

281281
if d_input is not None:
282282
do_pred['d_input'] = np.where(d <= d_input, 'ID', 'OD')
@@ -349,16 +349,16 @@ def assign_ground_truth(data_cv, bin_cv):
349349
bin_cv.loc[row, 'gt_area'] = group[3]
350350

351351
# Make labels
352-
absres = data_cv['absres/mad_y'] <= data_cv['gt_absres']
353-
rmse = data_cv['rmse/std_y'] <= data_cv['gt_rmse']
354-
area = data_cv['cdf_area'] <= data_cv['gt_area']
352+
absres = data_cv['absres/mad_y'] < data_cv['gt_absres']
353+
rmse = data_cv['rmse/std_y'] < data_cv['gt_rmse']
354+
area = data_cv['cdf_area'] < data_cv['gt_area']
355355

356356
data_cv['domain_absres/mad_y'] = np.where(absres, 'ID', 'OD')
357357
data_cv['domain_rmse/std_y'] = np.where(rmse, 'ID', 'OD')
358358
data_cv['domain_cdf_area'] = np.where(area, 'ID', 'OD')
359359

360-
rmse = bin_cv['rmse/std_y'] <= bin_cv['gt_rmse']
361-
area = bin_cv['cdf_area'] <= bin_cv['gt_area']
360+
rmse = bin_cv['rmse/std_y'] < bin_cv['gt_rmse']
361+
area = bin_cv['cdf_area'] < bin_cv['gt_area']
362362

363363
bin_cv['domain_rmse/std_y'] = np.where(rmse, 'ID', 'OD')
364364
bin_cv['domain_cdf_area'] = np.where(area, 'ID', 'OD')

src/madml/plots.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,12 @@ def sub(x, y, ylabel, key, gt, gtlabel, metric, color):
118118
ax_sub.set_xlabel(f'Included {key}')
119119
ax_sub.set_ylabel(ylabel)
120120

121-
dat = {'x': x, 'y': y, 'gt': gt}
121+
dat = {
122+
'x': list(map(float, x)),
123+
'y': list(map(float, y)),
124+
'gt': float(gt),
125+
}
126+
122127
plot_dump(
123128
dat,
124129
fig_sub,

0 commit comments

Comments
 (0)