Skip to content

Commit 4291594

Browse files
committed
Added left-out uncertainties for fitting UQ that is not used in assessing a model.
1 parent a2591ef commit 4291594

File tree

10 files changed

+22
-12
lines changed

10 files changed

+22
-12
lines changed

examples/single_runs/bw_rf/template/fit.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ def main():
7171
model = domain_model(gs_model, ds_model, uq_model, splits)
7272
cv = nested_cv(X, y, g, model, splits, save=run_name)
7373
cv.assess()
74-
cv.push('leschultz/cmg:{}'.format(data_name))
7574

7675

7776
if __name__ == '__main__':

examples/single_runs/gt_rf/make_runs.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ sets=(
99
"super_cond"
1010
)
1111

12-
gtgrid=(0.01 0.05 0.1 0.5 1.0 5.0 10.0 50.0)
12+
gtgrid=(0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08 0.09 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0)
1313

1414
for i in "${sets[@]}"
1515
do

examples/single_runs/gt_rf/template/fit.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ def main():
7171
model = domain_model(gs_model, ds_model, uq_model, splits, gts, gtb)
7272
cv = nested_cv(X, y, g, model, splits, save=run_name)
7373
cv.assess()
74-
cv.push('leschultz/cmg:{}'.format(data_name))
7574

7675

7776
if __name__ == '__main__':

examples/single_runs/kernel_rf/template/fit.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ def main():
7171
model = domain_model(gs_model, ds_model, uq_model, splits)
7272
cv = nested_cv(X, y, g, model, splits, save=run_name)
7373
cv.assess()
74-
cv.push('leschultz/cmg:{}'.format(data_name))
7574

7675

7776
if __name__ == '__main__':

examples/single_runs/nn/template/fit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def main():
111111
model = domain_model(gs_model, ds_model, uq_model, splits)
112112
cv = nested_cv(X, y, g, model, splits, save=run_name)
113113
cv.assess()
114-
cv.push('leschultz/cmg:{}'.format(data_name))
114+
cv.push('leschultz/cmg-nn-{}:latest'.format(data_name))
115115

116116

117117
if __name__ == '__main__':

examples/single_runs/ols/template/fit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def main():
7272
model = domain_model(gs_model, ds_model, uq_model, splits)
7373
cv = nested_cv(X, y, g, model, splits, save=run_name)
7474
cv.assess()
75-
cv.push('leschultz/cmg:{}'.format(data_name))
75+
cv.push('leschultz/cmg-ols-{}:latest'.format(data_name))
7676

7777

7878
if __name__ == '__main__':

examples/single_runs/rf/template/fit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def main():
7171
model = domain_model(gs_model, ds_model, uq_model, splits)
7272
cv = nested_cv(X, y, g, model, splits, save=run_name)
7373
cv.assess()
74-
cv.push('leschultz/cmg:{}'.format(data_name))
74+
cv.push('leschultz/cmg-rf-{}:latest'.format(data_name))
7575

7676

7777
if __name__ == '__main__':

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# Package information
44
name = 'madml'
5-
version = '0.6.3' # Need to increment every time to push to PyPI
5+
version = '0.6.4' # Need to increment every time to push to PyPI
66
description = 'Application domain of machine learning in materials science.'
77
url = 'https://github.com/leschultz/'\
88
'materials_application_domain_machine_learning.git'

src/madml/models/combine.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def __init__(
116116
self.uq_model = uq_model
117117
self.bins = bins
118118
self.save = save
119-
self.splits = splits
119+
self.splits = copy.deepcopy(splits)
120120
self.gts = gts
121121
self.gtb = gtb
122122

@@ -125,6 +125,17 @@ def __init__(
125125
if self.uq_model:
126126
self.dists.append('y_stdc/std(y)')
127127
self.methods.append('_bin')
128+
129+
# Add a splitter to calibrate UQ and prevent overfitting
130+
uqsplits = []
131+
for i in self.splits:
132+
if 'calibration' == i[0]:
133+
i = copy.deepcopy(i)
134+
i = ('fit', i[1])
135+
uqsplits.append(i)
136+
137+
self.splits += uqsplits
138+
128139
if self.ds_model:
129140
self.dists.append('dist')
130141

@@ -269,14 +280,16 @@ def fit(self, X, y, g):
269280
data_cv['y/std(y)'] = data_cv['y']/data_cv['std(y)']
270281
data_cv['id'] = abs(data_cv['r/std(y)']) < self.gts # Ground truth
271282

272-
# Fit on hold out data ID
283+
# Fit UQ on hold out data ID
273284
if self.uq_model:
274-
data_id = data_cv[data_cv['splitter'] == 'calibration']
285+
data_id = data_cv[data_cv['splitter'] == 'fit']
275286
self.uq_model.fit(
276287
data_id['y'].values,
277288
data_id['y_pred'].values,
278289
data_id['y_stdu'].values
279290
)
291+
292+
data_cv = data_cv[data_cv['splitter'] != 'fit']
280293
data_cv['y_stdc'] = self.uq_model.predict(data_cv['y_stdu'].values)
281294
data_cv['y_stdc/std(y)'] = data_cv['y_stdc']/data_cv['std(y)']
282295
data_cv['z'] = data_cv['r']/data_cv['y_stdc']

src/madml/plots.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def generate_plots(data_cv, ystd, bins, save, gts, gtb, dists):
7272
)
7373

7474
# For each splitter of data
75-
for split, values in data_cv.groupby(['splitter']):
75+
for split, values in data_cv.groupby('splitter'):
7676

7777
sub = '{}'.format(split)
7878
parity(

0 commit comments

Comments
 (0)