Skip to content

Commit

Permalink
Added left out uncertainties for fitting UQ that is not used in asse…
Browse files Browse the repository at this point in the history
…ssing a model.
  • Loading branch information
leschultz committed Sep 7, 2023
1 parent a2591ef commit 4291594
Show file tree
Hide file tree
Showing 10 changed files with 22 additions and 12 deletions.
1 change: 0 additions & 1 deletion examples/single_runs/bw_rf/template/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def main():
model = domain_model(gs_model, ds_model, uq_model, splits)
cv = nested_cv(X, y, g, model, splits, save=run_name)
cv.assess()
cv.push('leschultz/cmg:{}'.format(data_name))


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion examples/single_runs/gt_rf/make_runs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ sets=(
"super_cond"
)

gtgrid=(0.01 0.05 0.1 0.5 1.0 5.0 10.0 50.0)
gtgrid=(0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08 0.09 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0)

for i in "${sets[@]}"
do
Expand Down
1 change: 0 additions & 1 deletion examples/single_runs/gt_rf/template/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def main():
model = domain_model(gs_model, ds_model, uq_model, splits, gts, gtb)
cv = nested_cv(X, y, g, model, splits, save=run_name)
cv.assess()
cv.push('leschultz/cmg:{}'.format(data_name))


if __name__ == '__main__':
Expand Down
1 change: 0 additions & 1 deletion examples/single_runs/kernel_rf/template/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def main():
model = domain_model(gs_model, ds_model, uq_model, splits)
cv = nested_cv(X, y, g, model, splits, save=run_name)
cv.assess()
cv.push('leschultz/cmg:{}'.format(data_name))


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion examples/single_runs/nn/template/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def main():
model = domain_model(gs_model, ds_model, uq_model, splits)
cv = nested_cv(X, y, g, model, splits, save=run_name)
cv.assess()
cv.push('leschultz/cmg:{}'.format(data_name))
cv.push('leschultz/cmg-nn-{}:latest'.format(data_name))


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion examples/single_runs/ols/template/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def main():
model = domain_model(gs_model, ds_model, uq_model, splits)
cv = nested_cv(X, y, g, model, splits, save=run_name)
cv.assess()
cv.push('leschultz/cmg:{}'.format(data_name))
cv.push('leschultz/cmg-ols-{}:latest'.format(data_name))


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion examples/single_runs/rf/template/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def main():
model = domain_model(gs_model, ds_model, uq_model, splits)
cv = nested_cv(X, y, g, model, splits, save=run_name)
cv.assess()
cv.push('leschultz/cmg:{}'.format(data_name))
cv.push('leschultz/cmg-rf-{}:latest'.format(data_name))


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Package information
name = 'madml'
version = '0.6.3' # Need to increment every time to push to PyPI
version = '0.6.4' # Need to increment every time to push to PyPI
description = 'Application domain of machine learning in materials science.'
url = 'https://github.com/leschultz/'\
'materials_application_domain_machine_learning.git'
Expand Down
19 changes: 16 additions & 3 deletions src/madml/models/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __init__(
self.uq_model = uq_model
self.bins = bins
self.save = save
self.splits = splits
self.splits = copy.deepcopy(splits)
self.gts = gts
self.gtb = gtb

Expand All @@ -125,6 +125,17 @@ def __init__(
if self.uq_model:
self.dists.append('y_stdc/std(y)')
self.methods.append('_bin')

# Add a splitter to calibrate UQ and prevent overfitting
uqsplits = []
for i in self.splits:
if 'calibration' == i[0]:
i = copy.deepcopy(i)
i = ('fit', i[1])
uqsplits.append(i)

self.splits += uqsplits

if self.ds_model:
self.dists.append('dist')

Expand Down Expand Up @@ -269,14 +280,16 @@ def fit(self, X, y, g):
data_cv['y/std(y)'] = data_cv['y']/data_cv['std(y)']
data_cv['id'] = abs(data_cv['r/std(y)']) < self.gts # Ground truth

# Fit on hold out data ID
# Fit UQ on hold out data ID
if self.uq_model:
data_id = data_cv[data_cv['splitter'] == 'calibration']
data_id = data_cv[data_cv['splitter'] == 'fit']
self.uq_model.fit(
data_id['y'].values,
data_id['y_pred'].values,
data_id['y_stdu'].values
)

data_cv = data_cv[data_cv['splitter'] != 'fit']
data_cv['y_stdc'] = self.uq_model.predict(data_cv['y_stdu'].values)
data_cv['y_stdc/std(y)'] = data_cv['y_stdc']/data_cv['std(y)']
data_cv['z'] = data_cv['r']/data_cv['y_stdc']
Expand Down
2 changes: 1 addition & 1 deletion src/madml/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def generate_plots(data_cv, ystd, bins, save, gts, gtb, dists):
)

# For each splitter of data
for split, values in data_cv.groupby(['splitter']):
for split, values in data_cv.groupby('splitter'):

sub = '{}'.format(split)
parity(
Expand Down

0 comments on commit 4291594

Please sign in to comment.