Added manual threshold choice and updated a sample
leschultz committed Oct 2, 2023
1 parent 448fb07 commit 4c57f19
Showing 2 changed files with 100 additions and 26 deletions.
43 changes: 41 additions & 2 deletions examples/jupyter/tutorial_1.ipynb
@@ -253,8 +253,14 @@
     "id": "HCL2WngJWphB"
    },
    "source": [
-    "# Assessing and Fitting the Model\n",
-    "Now, we assess the model through cross valication and then fit a final model on all data that can be used by a researcher. The assessment of the model is saved in a directory of the user's choice."
+    "# Assessing and Fitting the Model"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can fit a single model without assessment, which is faster because no nested cross validation is run. However, overfitting may occur."
+   ]
+  },
   {
@@ -266,6 +272,22 @@
    "outputs": [],
    "source": [
     "model = domain_model(gs_model, ds_model, uq_model, splits)\n",
+    "model.fit(X, y, g)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can assess the model through cross validation and then fit a final model on all data that can be used by a researcher. The assessment of the model is saved in a directory of the user's choice. The fitting from the previous step is automatically performed after the nested cross validation to provide a usable model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
     "cv = nested_cv(X, y, g, model, splits, save='./runs')\n",
     "_, model = cv.assess()"
    ]
@@ -291,6 +313,23 @@
     "df = model.predict(X)\n",
     "print(df)"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The predefined domain thresholds may be insufficient. We can instead supply manual thresholds as a list of tuples of the form ('dissimilarity measure', 'domain of id or od', 'threshold') as follows:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df = model.predict(X, [('dist', 'id', 0.75), ('dist', 'id_bin', 0.2)])\n",
+    "print(df)"
+   ]
   }
  ],
  "metadata": {
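The tuple format passed to `model.predict` in the last notebook cell can be checked before use. The following is a minimal sketch and not part of madml: `check_manual_thresholds` is a hypothetical helper, and the rule that a domain label must contain 'id' or 'od' is inferred from the `domain_preds` change in this commit.

```python
def check_manual_thresholds(manual_th):
    '''
    Validate a list of (measure, domain_type, threshold) tuples.
    Hypothetical helper for illustration; madml does not provide it.
    '''

    checked = []
    for item in manual_th:

        # Each entry must unpack into exactly three values
        if len(item) != 3:
            raise ValueError('Expected 3 values, got: {}'.format(item))

        dist, domain_type, th = item

        # Assumed rule: the label must mention 'id' or 'od'
        if 'id' not in domain_type and 'od' not in domain_type:
            raise ValueError('Unknown domain type: {}'.format(domain_type))

        # Store the threshold as a float so comparisons are numeric
        checked.append((dist, domain_type, float(th)))

    return checked


manual_th = check_manual_thresholds([
    ('dist', 'id', 0.75),
    ('dist', 'id_bin', 0.2),
])
print(manual_th)
```

The validated list could then be passed as the second argument of `model.predict`, as in the notebook cell.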
83 changes: 59 additions & 24 deletions src/madml/models/combine.py
@@ -44,7 +44,14 @@ def domain_pred(dist, dist_cut, domain):
     return do_pred


-def domain_preds(pred, dists, methods, thresholds, suffix=''):
+def domain_preds(
+    pred,
+    dists=None,
+    methods=None,
+    thresholds=None,
+    suffix='',
+    manual_th=None,
+):
     '''
     The domain predictor based on thresholds.
@@ -59,27 +66,46 @@ def domain_preds(pred, dists, methods, thresholds, suffix=''):
         pred = The predictions of domain.
     '''

-    for i in dists:
+    if manual_th is not None:
+        for i in manual_th:

-        if suffix:
-            col = i+suffix
-        else:
-            col = i
+            dist, domain_type, th = i

+            if 'id' in domain_type:
+                domain = True
+            elif 'od' in domain_type:
+                domain = False
+
-        for j, k in zip([True, False], ['id', 'od']):
-            for method in methods:
-                k += method
-                for key, value in thresholds[i][k].items():
+            name = '{} by {} for {}'.format(domain_type.upper(), dist, th)
+            do_pred = domain_pred(
+                pred[dist],
+                th,
+                domain,
+            )
+
+            pred[name] = do_pred
+    else:
+        for i in dists:

-                    name = '{} by {} for {}'.format(k.upper(), i, key)
+            if suffix:
+                col = i+suffix
+            else:
+                col = i

-                    do_pred = domain_pred(
-                        pred[col],
-                        value['Threshold'],
-                        j,
-                    )
+            for j, k in zip([True, False], ['id', 'od']):
+                for method in methods:
+                    k += method
+                    for key, value in thresholds[i][k].items():

-                    pred[name] = do_pred
+                        name = '{} by {} for {}'.format(k.upper(), i, key)
+
+                        do_pred = domain_pred(
+                            pred[col],
+                            value['Threshold'],
+                            j,
+                        )
+
+                        pred[name] = do_pred

     return pred

@@ -359,12 +385,13 @@ def fit(self, X, y, g):

         return data_cv, data_cv_bin

-    def predict(self, X):
+    def predict(self, X, manual_th=None):
         '''
         Give domain classification along with other regression predictions.
         inputs:
             X = The features.
+            manual_th = Manual thresholds.
         outputs:
             pred = A pandas dataframe containing prediction data.
@@ -400,11 +427,19 @@ def predict(self, X):
         pred['y_stdc/std(y)'] = y_stdc_norm

         pred = pd.DataFrame(pred)
-        pred = domain_preds(
-            pred,
-            self.dists,
-            self.methods,
-            self.thresholds,
-        )
+
+        if manual_th is None:
+            pred = domain_preds(
+                pred,
+                self.dists,
+                self.methods,
+                self.thresholds,
+            )
+        else:
+
+            pred = domain_preds(
+                pred,
+                manual_th=manual_th
+            )

         return pred
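The manual branch of `domain_preds` derives each output column name from the tuple itself. Below is a minimal standalone sketch of that branch under stated assumptions: a plain dict of lists stands in for the pandas DataFrame, `manual_domain_preds` is a hypothetical name, and `threshold_stub` is a placeholder that does not reproduce `domain_pred`, whose body is not shown in this diff. The final `else` is also an addition, since the committed code leaves `domain` unset when the label contains neither 'id' nor 'od'.

```python
def threshold_stub(dist, dist_cut, domain):
    # Placeholder rule (an assumption): flag values below the cutoff.
    # The domain flag is accepted only to match the call signature.
    return [d < dist_cut for d in dist]


def manual_domain_preds(pred, manual_th, classify=threshold_stub):
    '''
    Add one column per (measure, domain_type, threshold) tuple.
    Hypothetical sketch of the manual branch; not the madml function.
    '''

    for dist, domain_type, th in manual_th:

        if 'id' in domain_type:
            domain = True
        elif 'od' in domain_type:
            domain = False
        else:
            # Not in the committed code: fail instead of leaving domain unset
            raise ValueError('Unknown domain type: {}'.format(domain_type))

        # Same naming scheme as the committed manual branch
        name = '{} by {} for {}'.format(domain_type.upper(), dist, th)
        pred[name] = classify(pred[dist], th, domain)

    return pred


pred = {'dist': [0.1, 0.5, 0.9]}
pred = manual_domain_preds(
    pred,
    [('dist', 'id', 0.75), ('dist', 'id_bin', 0.2)],
)
print(sorted(pred))
```

Because the threshold value is part of the column name, two tuples that share a measure and label but differ in threshold produce separate columns rather than overwriting each other.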
