Commit 4c57f19

Added manual threshold choice and updated a sample
1 parent 448fb07 commit 4c57f19

2 files changed: 100 additions & 26 deletions


examples/jupyter/tutorial_1.ipynb

Lines changed: 41 additions & 2 deletions
@@ -253,8 +253,14 @@
     "id": "HCL2WngJWphB"
    },
    "source": [
-    "# Assessing and Fitting the Model\n",
-    "Now, we assess the model through cross valication and then fit a final model on all data that can be used by a researcher. The assessment of the model is saved in a directory of the user's choice."
+    "# Assessing and Fitting the Model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can fit a single model without assessment, which is faster because it skips nested cross validation. However, overfitting may occur and will not be detected without assessment."
    ]
   },
   {
@@ -266,6 +272,22 @@
    "outputs": [],
    "source": [
     "model = domain_model(gs_model, ds_model, uq_model, splits)\n",
+    "model.fit(X, y, g)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Alternatively, we can assess the model through nested cross validation and then fit a final model on all data for a researcher to use. The assessment is saved in a directory of the user's choice. The fit from the previous step is performed automatically after the nested cross validation to provide a usable model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
     "cv = nested_cv(X, y, g, model, splits, save='./runs')\n",
     "_, model = cv.assess()"
    ]
@@ -291,6 +313,23 @@
     "df = model.predict(X)\n",
     "print(df)"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The predefined domain thresholds may be insufficient. We can instead pass manual thresholds as a list of tuples of the form ('dissimilarity measure', 'domain type containing id or od', threshold), as follows:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df = model.predict(X, [('dist', 'id', 0.75), ('dist', 'id_bin', 0.2)])\n",
+    "print(df)"
+   ]
   }
  ],
  "metadata": {
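
Review note: the manual-threshold call above adds one boolean column per tuple, named with the `'{} by {} for {}'` pattern from `domain_preds`. The following is a minimal sketch of that behavior, not the library's implementation. It assumes a stand-in `domain_pred` (the diff shows only its signature, so the comparison direction is a guess), uses a plain dict of lists in place of the pandas DataFrame, and the helper name `manual_domain_preds` is hypothetical.

```python
# Hypothetical stand-in for domain_pred: only its signature appears in the
# diff, so the "below the cutoff" rule here is an assumption.
def domain_pred(dist, dist_cut, domain):
    # Assign `domain` to points under the cutoff and the opposite label otherwise.
    return [domain if d < dist_cut else not domain for d in dist]


def manual_domain_preds(pred, manual_th):
    # `pred` maps column names to lists, standing in for a pandas DataFrame.
    for dist, domain_type, th in manual_th:
        # True for types containing 'id', False otherwise. The committed code
        # uses if/elif on 'id' and 'od' and sets nothing when neither matches.
        domain = 'id' in domain_type
        name = '{} by {} for {}'.format(domain_type.upper(), dist, th)
        pred[name] = domain_pred(pred[dist], th, domain)
    return pred


pred = {'dist': [0.1, 0.5, 0.9]}
pred = manual_domain_preds(
    pred,
    [('dist', 'id', 0.75), ('dist', 'id_bin', 0.2)],
)
for column, values in pred.items():
    print(column, values)
```

With the tuples from the notebook cell, the new columns are named `ID by dist for 0.75` and `ID_BIN by dist for 0.2`, so the threshold value is part of the column label.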

src/madml/models/combine.py

Lines changed: 59 additions & 24 deletions
@@ -44,7 +44,14 @@ def domain_pred(dist, dist_cut, domain):
     return do_pred
 
 
-def domain_preds(pred, dists, methods, thresholds, suffix=''):
+def domain_preds(
+    pred,
+    dists=None,
+    methods=None,
+    thresholds=None,
+    suffix='',
+    manual_th=None,
+    ):
     '''
     The domain predictor based on thresholds.
 
@@ -59,27 +66,46 @@ def domain_preds(pred, dists, methods, thresholds, suffix=''):
         pred = The predictions of domain.
     '''
 
-    for i in dists:
+    if manual_th is not None:
+        for i in manual_th:
 
-        if suffix:
-            col = i+suffix
-        else:
-            col = i
+            dist, domain_type, th = i
+
+            if 'id' in domain_type:
+                domain = True
+            elif 'od' in domain_type:
+                domain = False
 
-        for j, k in zip([True, False], ['id', 'od']):
-            for method in methods:
-                k += method
-                for key, value in thresholds[i][k].items():
+            name = '{} by {} for {}'.format(domain_type.upper(), dist, th)
+            do_pred = domain_pred(
+                pred[dist],
+                th,
+                domain,
+                )
+
+            pred[name] = do_pred
+    else:
+        for i in dists:
 
-                    name = '{} by {} for {}'.format(k.upper(), i, key)
+            if suffix:
+                col = i+suffix
+            else:
+                col = i
 
-                    do_pred = domain_pred(
-                        pred[col],
-                        value['Threshold'],
-                        j,
-                        )
+            for j, k in zip([True, False], ['id', 'od']):
+                for method in methods:
+                    k += method
+                    for key, value in thresholds[i][k].items():
 
-                    pred[name] = do_pred
+                        name = '{} by {} for {}'.format(k.upper(), i, key)
+
+                        do_pred = domain_pred(
+                            pred[col],
+                            value['Threshold'],
+                            j,
+                            )
+
+                        pred[name] = do_pred
 
     return pred
 
@@ -359,12 +385,13 @@ def fit(self, X, y, g):
 
         return data_cv, data_cv_bin
 
-    def predict(self, X):
+    def predict(self, X, manual_th=None):
         '''
         Give domain classification along with other regression predictions.
 
         inputs:
             X = The features.
+            manual_th = Optional list of (dissimilarity measure, domain type, threshold) tuples used in place of the fitted thresholds.
 
         outputs:
             pred = A pandas dataframe containing prediction data.
@@ -400,11 +427,19 @@ def predict(self, X):
         pred['y_stdc/std(y)'] = y_stdc_norm
 
         pred = pd.DataFrame(pred)
-        pred = domain_preds(
-            pred,
-            self.dists,
-            self.methods,
-            self.thresholds,
-            )
+
+        if manual_th is None:
+            pred = domain_preds(
+                pred,
+                self.dists,
+                self.methods,
+                self.thresholds,
+                )
+        else:
+
+            pred = domain_preds(
+                pred,
+                manual_th=manual_th
+                )
 
         return pred
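
Review note: in the manual branch of `domain_preds`, `domain` is only assigned when the domain type contains `'id'` or `'od'`, so any other label would either raise a `NameError` or silently reuse the flag from the previous tuple. The substring test also matches unrelated labels that happen to contain those letters. One possible guard is sketched below; `parse_manual_threshold` is a hypothetical helper that is not part of this commit, and matching on the prefix is an assumption based on the `'id'` and `'id_bin'` labels used in the notebook.

```python
def parse_manual_threshold(entry):
    """Split one manual threshold tuple and derive its domain flag.

    Returns (dist, domain_type, th, domain) or raises ValueError.
    """
    try:
        dist, domain_type, th = entry
    except (TypeError, ValueError):
        raise ValueError(
            'expected (dist, domain_type, threshold), got {!r}'.format(entry)
        )

    # Match on the prefix rather than a substring, so a label such as
    # 'valid' is not mistaken for an 'id' type.
    if domain_type.startswith('id'):
        domain = True
    elif domain_type.startswith('od'):
        domain = False
    else:
        raise ValueError(
            'domain type must start with id or od: {!r}'.format(domain_type)
        )

    return dist, domain_type, th, domain


print(parse_manual_threshold(('dist', 'id_bin', 0.2)))
```

Calling such a helper at the top of the loop would make a malformed tuple fail with a clear message before any column is written.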
