From 14b0af795c622afc051213eb2c7a76c2ddd766c8 Mon Sep 17 00:00:00 2001 From: leschultz Date: Tue, 30 Apr 2024 16:00:49 -0500 Subject: [PATCH] Adding residual domain --- src/madml/assess.py | 2 ++ src/madml/calculators.py | 3 +++ src/madml/models.py | 23 ++++++++++++++++++++--- src/madml/plots.py | 2 +- 4 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/madml/assess.py b/src/madml/assess.py index 7eb5562..7b2d8ed 100644 --- a/src/madml/assess.py +++ b/src/madml/assess.py @@ -49,6 +49,7 @@ def __init__( self.n_jobs = n_jobs # If user defined + self.gt_absres = self.model.gt_absres self.gt_rmse = self.model.gt_rmse self.gt_area = self.model.gt_area @@ -130,6 +131,7 @@ def cv(self, split, save_inner_folds=None): data['y_stdc_pred/std_y'] = data['y_stdc_pred']/data['std_y'] # Ground truths + data['gt_absres'] = model.gt_absres data['gt_rmse'] = model.gt_rmse data['gt_area'] = model.gt_area diff --git a/src/madml/calculators.py b/src/madml/calculators.py index 6ad2cd8..d8b27d7 100644 --- a/src/madml/calculators.py +++ b/src/madml/calculators.py @@ -286,6 +286,9 @@ def ground_truth(self, y): std_y = np.std(y) mean_y = np.mean(y) + if self.gt_absres is None: + self.gt_absres = 1.0 + if self.gt_rmse is None: mean = np.repeat(mean_y, y.shape[0]) naive_rmse = mean_squared_error( diff --git a/src/madml/models.py b/src/madml/models.py index 9d711af..6c44c19 100644 --- a/src/madml/models.py +++ b/src/madml/models.py @@ -336,20 +336,24 @@ def assign_ground_truth(data_cv, bin_cv): data_cv = data_cv.merge(bin_cv, on=['bin']) # Innitiate arrays - cols = ['gt_rmse', 'gt_area'] + cols = ['gt_absres', 'gt_rmse', 'gt_area'] for c in cols: bin_cv[c] = None # Propagate ground truths for group, value in data_cv.groupby(['bin', *cols], observed=True): row = bin_cv['bin'] == group[0] - bin_cv.loc[row, 'gt_rmse'] = group[1] - bin_cv.loc[row, 'gt_area'] = group[2] + + bin_cv.loc[row, 'gt_absres'] = group[1] + bin_cv.loc[row, 'gt_rmse'] = group[2] + bin_cv.loc[row, 'gt_area'] = 
group[3] # Make labels + absres = data_cv['|r|/mad_y'] <= data_cv['gt_absres'] rmse = data_cv['rmse/std_y'] <= data_cv['gt_rmse'] area = data_cv['cdf_area'] <= data_cv['gt_area'] + data_cv['domain_absres/mad_y'] = np.where(absres, 'ID', 'OD') data_cv['domain_rmse/std_y'] = np.where(rmse, 'ID', 'OD') data_cv['domain_cdf_area'] = np.where(area, 'ID', 'OD') @@ -376,6 +380,7 @@ def __init__( splits=[('fit', RepeatedKFold(n_repeats=2))], bins=10, precs=[], + gt_absres=None, gt_rmse=None, gt_area=None, disable_tqdm=False, @@ -389,6 +394,7 @@ def __init__( splits = The list of splitting generators. bins = The number of quantailes for binning data. precs = The minimum preicisions for domain model. + gt_absres = The ground truth for absolute residuals. gt_rmse = The ground truth for rmse. gt_area = The ground truth for miscalibration area. ''' @@ -397,6 +403,7 @@ def __init__( self.ds_model = ds_model self.uq_model = uq_model self.bins = bins + self.gt_absres = gt_absres self.gt_rmse = gt_rmse self.gt_area = gt_area self.splits = copy.deepcopy(splits) @@ -560,6 +567,7 @@ def fit(self, X, y, g=None, d_input=None, n_jobs=-1): # Acquire ground truths self = ground_truth(self, y) + data_cv['gt_absres'] = self.gt_absres data_cv['gt_rmse'] = self.gt_rmse data_cv['gt_area'] = self.gt_area @@ -570,10 +578,15 @@ def fit(self, X, y, g=None, d_input=None, n_jobs=-1): ) # Fit domain classifiers + self.domain_absres = domain(self.precs) self.domain_rmse = domain(self.precs) self.domain_area = domain(self.precs) # Train classifiers + self.domain_absres.fit( + data_cv['d_pred'].values, + data_cv['domain_absres/mad_y'].values, + ) self.domain_rmse.fit( data_cv['d_pred_max'].values, data_cv['domain_rmse/std_y'].values, @@ -599,6 +612,9 @@ def combine_domains_preds(self, d, d_input=None): ''' # Predict domains on training data + data_absres_dom_pred = self.domain_absres.predict(d, d_input) + data_absres_dom_pred = data_absres_dom_pred.add_prefix('absres/mad_y ') + data_rmse_dom_pred = 
self.domain_rmse.predict(d, d_input) data_rmse_dom_pred = data_rmse_dom_pred.add_prefix('rmse/std_y ') @@ -606,6 +622,7 @@ def combine_domains_preds(self, d, d_input=None): data_area_dom_pred = data_area_dom_pred.add_prefix('cdf_area ') dom_pred = pd.concat([ + data_absres_dom_pred, data_rmse_dom_pred, data_area_dom_pred, ], axis=1) diff --git a/src/madml/plots.py b/src/madml/plots.py index 4f96e5f..68538d5 100644 --- a/src/madml/plots.py +++ b/src/madml/plots.py @@ -176,7 +176,7 @@ def sub(x, y, ylabel, key, gt, gtlabel, metric, color): indx = np.argsort(value) d = value[indx] - y = sorters[r'$|y-\hat{y}|$'][indx] + y = df['|r|/std_y'].values[indx] z = df['z'].values[indx] out = parallel(