Skip to content

Commit

Permalink
Adding residual domain
Browse files Browse the repository at this point in the history
  • Loading branch information
leschultz committed Apr 30, 2024
1 parent a46e88c commit 14b0af7
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/madml/assess.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
self.n_jobs = n_jobs

# If user defined
self.gt_absres = self.model.gt_absres
self.gt_rmse = self.model.gt_rmse
self.gt_area = self.model.gt_area

Expand Down Expand Up @@ -130,6 +131,7 @@ def cv(self, split, save_inner_folds=None):
data['y_stdc_pred/std_y'] = data['y_stdc_pred']/data['std_y']

# Ground truths
data['gt_absres'] = model.gt_absres
data['gt_rmse'] = model.gt_rmse
data['gt_area'] = model.gt_area

Expand Down
3 changes: 3 additions & 0 deletions src/madml/calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,9 @@ def ground_truth(self, y):
std_y = np.std(y)
mean_y = np.mean(y)

if self.gt_absres is None:
self.gt_absres = 1.0

if self.gt_rmse is None:
mean = np.repeat(mean_y, y.shape[0])
naive_rmse = mean_squared_error(
Expand Down
23 changes: 20 additions & 3 deletions src/madml/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,20 +336,24 @@ def assign_ground_truth(data_cv, bin_cv):
data_cv = data_cv.merge(bin_cv, on=['bin'])

# Innitiate arrays
cols = ['gt_rmse', 'gt_area']
cols = ['gt_absres', 'gt_rmse', 'gt_area']
for c in cols:
bin_cv[c] = None

# Propagate ground truths
for group, value in data_cv.groupby(['bin', *cols], observed=True):
row = bin_cv['bin'] == group[0]
bin_cv.loc[row, 'gt_rmse'] = group[1]
bin_cv.loc[row, 'gt_area'] = group[2]

bin_cv.loc[row, 'gt_absres'] = group[1]
bin_cv.loc[row, 'gt_rmse'] = group[2]
bin_cv.loc[row, 'gt_area'] = group[3]

# Make labels
absres = data_cv['|r|/mad_y'] <= data_cv['gt_absres']
rmse = data_cv['rmse/std_y'] <= data_cv['gt_rmse']
area = data_cv['cdf_area'] <= data_cv['gt_area']

data_cv['domain_absres/mad_y'] = np.where(absres, 'ID', 'OD')
data_cv['domain_rmse/std_y'] = np.where(rmse, 'ID', 'OD')
data_cv['domain_cdf_area'] = np.where(area, 'ID', 'OD')

Expand All @@ -376,6 +380,7 @@ def __init__(
splits=[('fit', RepeatedKFold(n_repeats=2))],
bins=10,
precs=[],
gt_absres=None,
gt_rmse=None,
gt_area=None,
disable_tqdm=False,
Expand All @@ -389,6 +394,7 @@ def __init__(
splits = The list of splitting generators.
bins = The number of quantailes for binning data.
precs = The minimum preicisions for domain model.
gt_absres = The ground truth for absoulte residuals.
gt_rmse = The ground truth for rmse.
gt_area = The ground truth for miscalibration area.
'''
Expand All @@ -397,6 +403,7 @@ def __init__(
self.ds_model = ds_model
self.uq_model = uq_model
self.bins = bins
self.gt_absres = gt_absres
self.gt_rmse = gt_rmse
self.gt_area = gt_area
self.splits = copy.deepcopy(splits)
Expand Down Expand Up @@ -560,6 +567,7 @@ def fit(self, X, y, g=None, d_input=None, n_jobs=-1):

# Acquire ground truths
self = ground_truth(self, y)
data_cv['gt_absres'] = self.gt_absres
data_cv['gt_rmse'] = self.gt_rmse
data_cv['gt_area'] = self.gt_area

Expand All @@ -570,10 +578,15 @@ def fit(self, X, y, g=None, d_input=None, n_jobs=-1):
)

# Fit domain classifiers
self.domain_absres = domain(self.precs)
self.domain_rmse = domain(self.precs)
self.domain_area = domain(self.precs)

# Train classifiers
self.domain_absres.fit(
data_cv['d_pred'].values,
data_cv['domain_absres/mad_y'].values,
)
self.domain_rmse.fit(
data_cv['d_pred_max'].values,
data_cv['domain_rmse/std_y'].values,
Expand All @@ -599,13 +612,17 @@ def combine_domains_preds(self, d, d_input=None):
'''

# Predict domains on training data
data_absres_dom_pred = self.domain_absres.predict(d, d_input)
data_absres_dom_pred = data_absres_dom_pred.add_prefix('absres/mad_y ')

data_rmse_dom_pred = self.domain_rmse.predict(d, d_input)
data_rmse_dom_pred = data_rmse_dom_pred.add_prefix('rmse/std_y ')

data_area_dom_pred = self.domain_area.predict(d, d_input)
data_area_dom_pred = data_area_dom_pred.add_prefix('cdf_area ')

dom_pred = pd.concat([
data_absres_dom_pred,
data_rmse_dom_pred,
data_area_dom_pred,
], axis=1)
Expand Down
2 changes: 1 addition & 1 deletion src/madml/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def sub(x, y, ylabel, key, gt, gtlabel, metric, color):

indx = np.argsort(value)
d = value[indx]
y = sorters[r'$|y-\hat{y}|$'][indx]
y = df['|r|/std_y'].values[indx]
z = df['z'].values[indx]

out = parallel(
Expand Down

0 comments on commit 14b0af7

Please sign in to comment.