From 6c103f80bd42bbf149daf0c34db712a0ef943aaf Mon Sep 17 00:00:00 2001 From: leschultz Date: Tue, 30 Jan 2024 10:13:18 -0600 Subject: [PATCH] Now feel good about confusion plots --- setup.py | 2 +- src/madml/calculators.py | 1 - src/madml/models.py | 8 +++++++- src/madml/plots.py | 1 - 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index d8085c4..7f6878f 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ # Package information name = 'madml' -version = '2.1.0' # Need to increment every time to push to PyPI +version = '2.1.2' # Need to increment every time to push to PyPI description = 'Application domain of machine learning in materials science.' url = 'https://github.com/leschultz/'\ 'materials_application_domain_machine_learning.git' diff --git a/src/madml/calculators.py b/src/madml/calculators.py index 9517f03..b94567f 100644 --- a/src/madml/calculators.py +++ b/src/madml/calculators.py @@ -172,7 +172,6 @@ def pr(d, labels, precs): # Maximum F1 score max_f1_index = np.argmax(f1_scores) - print(f1_scores, max_f1_index) data['Max F1'] = { 'Precision': precision[max_f1_index], diff --git a/src/madml/models.py b/src/madml/models.py index a3e429e..1bc99eb 100644 --- a/src/madml/models.py +++ b/src/madml/models.py @@ -406,7 +406,7 @@ def cv(self, split, gs_model, ds_model, X, y, g=None): return data - def fit(self, X, y, g=None): + def fit(self, X, y, g=None, d_input=None): ''' Fit all models. Thresholds for domain classification are also set. @@ -506,6 +506,12 @@ def fit(self, X, y, g=None): bin_cv['domain_cdf_area'].values, ) + pred = self.combine_domains_preds(data_cv['d_pred'], d_input) + data_cv = pd.concat([ + data_cv.reset_index(drop=True), + pred.reset_index(drop=True), + ], axis=1) + self.data_cv = data_cv self.bin_cv = bin_cv diff --git a/src/madml/plots.py b/src/madml/plots.py index daf307c..bafdf66 100644 --- a/src/madml/plots.py +++ b/src/madml/plots.py @@ -5,7 +5,6 @@ ) from matplotlib import pyplot as pl -from madml.models import domain from madml import calculators from sklearn import metrics