-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
New dictionary entry for confidence plots
- Loading branch information
Showing
13 changed files
with
222 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,4 +3,4 @@ | |
export PYTHONPATH=$(pwd)/../../../../../../src:$PYTHONPATH | ||
|
||
rm -rf run | ||
python3 fit.py | ||
time python3 fit.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,6 @@ | ||
#!/bin/bash | ||
|
||
data=( | ||
"friedman1" | ||
"fluence" | ||
"diffusion" | ||
"strength" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,4 +3,4 @@ | |
export PYTHONPATH=$(pwd)/../../../../../src:$PYTHONPATH | ||
|
||
rm -rf run | ||
python3 fit.py | ||
time python3 fit.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,4 +3,4 @@ | |
export PYTHONPATH=$(pwd)/../../../../../../src:$PYTHONPATH | ||
|
||
rm -rf run | ||
python3 fit.py | ||
time python3 fit.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
#!/bin/bash
#
# Generate one run directory per (dataset, model) pair.
#
# For every pair, the contents of ./template are copied into
# runs/data_<dataset>/model_<model>/ and the placeholder tokens in the copied
# fit.py are replaced by the quoted dataset name, the quoted model name, and
# the number of repeats.
#
# Must be run from the directory that contains template/.

data=(
  "friedman1"
  "fluence"
  "diffusion"
  "strength"
  "supercond"
)

models=(
  "rf"
)

for i in "${data[@]}"; do
  for j in "${models[@]}"; do

    echo "Making (data, model)=(${i}, ${j})"
    job_dir="runs/data_${i}/model_${j}"

    # Abort on a failed copy: continuing would run sed against a missing or
    # stale fit.py and leave a half-configured run directory behind.
    mkdir -p -- "${job_dir}" || exit 1
    cp -r -- template/* "${job_dir}" || exit 1

    # Define the repeats (fewer for the slower bnn model on some datasets)
    case "${i}:${j}" in
      fluence:bnn | friedman1:bnn) r=3 ;;
      supercond:bnn) r=2 ;;
      *) r=5 ;;
    esac

    # Edit the copied file through its path instead of cd-ing into the run
    # directory, so a failed directory change can never redirect the edit.
    sed -i \
      -e "s/replace_data/'${i}'/g" \
      -e "s/replace_model/'${j}'/g" \
      -e "s/replace_repeats/${r}/g" \
      "${job_dir}/fit.py" || exit 1

  done
done
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
#!/bin/bash
#
# Submit every job script found beneath a directory.
#
# Usage: <this script> [search_dir]
#   search_dir  Directory searched recursively for regular files named
#               submit.sh. Defaults to the current directory.
#
# Each match is submitted with qsub from inside the directory that holds it.

submit=submit.sh
search_dir="${1:-.}"

# find emits NUL-terminated paths so names containing spaces, newlines or glob
# characters survive intact. The list is read on fd 3 so that qsub keeps the
# script's own stdin.
while IFS= read -r -d '' -u 3 i; do
  # Subshell: the directory change is scoped to this one submission, and a
  # failed cd skips the submission instead of running qsub somewhere else.
  (
    cd -- "$(dirname -- "${i}")" || exit 1
    qsub "${submit}"
  )
done 3< <(find "${search_dir}" -type f -name "${submit}" -print0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from sklearn.cluster import AgglomerativeClustering | ||
from sklearn.model_selection import RepeatedKFold | ||
from sklearn.model_selection import GridSearchCV | ||
from sklearn.preprocessing import MinMaxScaler | ||
from sklearn.pipeline import Pipeline | ||
|
||
from madml.models import dissimilarity, calibration, combine | ||
from madml.splitters import BootstrappedLeaveClusterOut | ||
from madml.assess import nested_cv | ||
from madml import datasets | ||
|
||
from mods import return_model | ||
|
||
|
||
def main():
    """
    Assess one (dataset, model) combination with nested cross validation.

    This file is a template: the three bare placeholder names assigned below
    are substituted (via sed) with a quoted dataset name, a quoted model name
    and an integer repeat count when a run directory is generated.
    """

    run_name = 'output'
    data_name = replace_data
    model_name = replace_model

    # Load the feature matrix and the target vector
    dataset = datasets.load(data_name)
    X = dataset['data']
    y = dataset['target']

    # MADML parameters
    bins = 10
    n_repeats = replace_repeats

    # Dissimilarity (distance) model
    ds_model = dissimilarity(dis='kde')

    # Uncertainty calibration model
    uq_model = calibration(params=[0.0, 1.0])

    # Regression pipeline: min-max scaling followed by the chosen regressor
    pipe = Pipeline(steps=[
        ('scaler', MinMaxScaler()),
        ('model', return_model(model_name, X)),
    ])

    # Grid search over an empty grid; cv is a single (train, test) pair whose
    # slices both select every row, i.e. no actual splitting takes place.
    gs_model = GridSearchCV(
        pipe,
        {},
        cv=((slice(None), slice(None)),),
        scoring='neg_mean_squared_error',
    )

    # Splitters to test: repeated 5-fold first, then bootstrapped
    # leave-cluster-out with 2 and 3 agglomerative clusters.
    splits = [('fit', RepeatedKFold(n_repeats=n_repeats, n_splits=5))]
    splits.extend(
        (
            'agglo_{}'.format(clusters),
            BootstrappedLeaveClusterOut(
                AgglomerativeClustering,
                n_repeats=n_repeats,
                n_clusters=clusters,
            ),
        )
        for clusters in (2, 3)
    )

    # Assess the combined model; outer folds are saved under run_name
    combined = combine(gs_model, ds_model, uq_model, splits, bins=bins)
    cv = nested_cv(combined, X, y, splitters=splits)
    df, df_bin, fit_model = cv.test(
        save_outer_folds=run_name,
    )


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from sklearn.ensemble import RandomForestRegressor | ||
from sklearn.linear_model import LinearRegression | ||
from sklearn.ensemble import BaggingRegressor | ||
from scikeras.wrappers import KerasRegressor | ||
from keras.layers import Dense, Dropout | ||
from keras.models import Sequential | ||
from sklearn.svm import SVR | ||
|
||
|
||
def return_model(name, X):
    """
    Build the unfitted regressor identified by a short name.

    Args:
        name: One of 'rf', 'bols', 'bsvr' or 'bnn'.
        X: Feature matrix. Only read for 'bnn', where X.shape[1] is handed
            to the network builder as its `shape` argument.

    Returns:
        'rf'   -> RandomForestRegressor with 100 trees.
        'bols' -> BaggingRegressor of 100 LinearRegression estimators.
        'bsvr' -> BaggingRegressor of 100 SVR estimators.
        'bnn'  -> BaggingRegressor of 10 KerasRegressor estimators.

    Raises:
        ValueError: If `name` is not one of the supported keys.
    """

    if name == 'rf':
        return RandomForestRegressor(n_estimators=100)

    if name == 'bols':
        return BaggingRegressor(LinearRegression(), n_estimators=100)

    if name == 'bsvr':
        return BaggingRegressor(SVR(), n_estimators=100)

    if name == 'bnn':
        model = KerasRegressor(
            build_fn=keras_model,
            shape=X.shape[1],
            epochs=500,
            batch_size=100,
            verbose=0,
        )

        return BaggingRegressor(model, n_estimators=10)

    # Raise a real exception type: raising a bare string is itself a
    # TypeError in Python 3 and hides the intended message.
    raise ValueError('No model matching name: {}'.format(name))
|
||
|
||
def keras_model(shape):
    """
    Assemble and compile a small fully connected regression network.

    Architecture, in order: Dense(100, relu) -> Dropout(0.3) ->
    Dense(100, relu) -> Dropout(0.3) -> Dense(1). Every Dense layer uses the
    'normal' kernel initializer. The network is compiled with mean squared
    error loss and the adam optimizer.

    Args:
        shape: Passed as `input_dim` of the first Dense layer.

    Returns:
        The compiled Sequential model.
    """

    width = 100
    drop_rate = 0.3
    init = 'normal'

    layers = [
        Dense(
            width,
            input_dim=shape,
            kernel_initializer=init,
            activation='relu'
        ),
        Dropout(drop_rate),
        Dense(
            width,
            kernel_initializer=init,
            activation='relu'
        ),
        Dropout(drop_rate),
        Dense(
            1,
            kernel_initializer=init
        ),
    ]

    network = Sequential()
    for layer in layers:
        network.add(layer)

    network.compile(
        loss='mean_squared_error',
        optimizer='adam'
    )

    return network
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/bin/bash
#
# Time a run of fit.py from the current directory, with the src directory
# five levels above it prepended to PYTHONPATH.

# Prepend <cwd>/../../../../../src to the Python import path
src_dir="$(pwd)/../../../../../src"
export PYTHONPATH="${src_dir}:${PYTHONPATH}"

# Remove any existing 'run' file or directory before starting
rm -rf run

time python3 fit.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
#PBS -S /bin/bash
#PBS -q bardeen
#PBS -l select=1:ncpus=16:mpiprocs=16
#PBS -l walltime=72:00:00
#PBS -N job

# Job script: change to the directory named by PBS_O_WORKDIR and start run.sh.
# Abort if that variable is unset/empty or the directory cannot be entered,
# rather than attempting ./run.sh from whatever directory the job started in.
cd "${PBS_O_WORKDIR:?PBS_O_WORKDIR is not set}" || exit 1

./run.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,4 +3,4 @@ | |
export PYTHONPATH=$(pwd)/../../../../../src:$PYTHONPATH | ||
|
||
rm -rf run | ||
python3 fit.py | ||
time python3 fit.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters