Skip to content

Commit

Permalink
Now feel good about confusion plots
Browse files Browse the repository at this point in the history
  • Loading branch information
leschultz committed Jan 30, 2024
1 parent 890bba5 commit 6c103f8
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Package information
name = 'madml'
version = '2.1.0' # Need to increment every time to push to PyPI
version = '2.1.2' # Need to increment every time to push to PyPI
description = 'Application domain of machine learning in materials science.'
url = 'https://github.com/leschultz/'\
'materials_application_domain_machine_learning.git'
Expand Down
1 change: 0 additions & 1 deletion src/madml/calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ def pr(d, labels, precs):

# Maximum F1 score
max_f1_index = np.argmax(f1_scores)
print(f1_scores, max_f1_index)

data['Max F1'] = {
'Precision': precision[max_f1_index],
Expand Down
8 changes: 7 additions & 1 deletion src/madml/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def cv(self, split, gs_model, ds_model, X, y, g=None):

return data

def fit(self, X, y, g=None):
def fit(self, X, y, g=None, d_input=None):
'''
Fit all models. Thresholds for domain classification are also set.
Expand Down Expand Up @@ -506,6 +506,12 @@ def fit(self, X, y, g=None):
bin_cv['domain_cdf_area'].values,
)

pred = self.combine_domains_preds(data_cv['d_pred'], d_input)
data_cv = pd.concat([
data_cv.reset_index(drop=True),
pred.reset_index(drop=True),
], axis=1)

self.data_cv = data_cv
self.bin_cv = bin_cv

Expand Down
1 change: 0 additions & 1 deletion src/madml/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
)

from matplotlib import pyplot as pl
from madml.models import domain
from madml import calculators
from sklearn import metrics

Expand Down

0 comments on commit 6c103f8

Please sign in to comment.