Skip to content

Commit 01811ed

Browse files
committed
warn if same data is used in transform as in fit
1 parent bab9bb2 commit 01811ed

File tree

2 files changed

+483
-0
lines changed

2 files changed

+483
-0
lines changed
Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {
6+
"pycharm": {
7+
"name": "#%% md\n"
8+
}
9+
},
10+
"source": [
11+
"From [pyvtreat issue 12](https://github.com/WinVector/pyvtreat/issues/12)"
12+
]
13+
},
14+
{
15+
"cell_type": "code",
16+
"execution_count": 1,
17+
"metadata": {
18+
"pycharm": {
19+
"is_executing": false
20+
}
21+
},
22+
"outputs": [
23+
{
24+
"name": "stdout",
25+
"output_type": "stream",
26+
"text": [
27+
"model score: 0.880\n"
28+
]
29+
}
30+
],
31+
"source": [
32+
"import pandas as pd\n",
33+
"import numpy as np\n",
34+
"import numpy.random\n",
35+
"import vtreat\n",
36+
"import vtreat.util\n",
37+
"from sklearn.linear_model import LogisticRegression\n",
38+
"from sklearn.pipeline import Pipeline\n",
39+
"from sklearn.model_selection import train_test_split\n",
40+
"\n",
41+
"numpy.random.seed(2019)\n",
42+
"\n",
43+
"def make_data(nrows):\n",
44+
" d = pd.DataFrame({'x': 5*numpy.random.normal(size=nrows)})\n",
45+
" d['y'] = numpy.sin(d['x']) + 0.1*numpy.random.normal(size=nrows)\n",
46+
" d.loc[numpy.arange(3, 10), 'x'] = numpy.nan # introduce a nan level\n",
47+
" d['xc'] = ['level_' + str(5*numpy.round(yi/5, 1)) for yi in d['y']]\n",
48+
" d['x2'] = np.random.normal(size=nrows)\n",
49+
" d.loc[d['xc']=='level_-1.0', 'xc'] = numpy.nan # introduce a nan level\n",
50+
" d['yc'] = d['y']>0.5\n",
51+
" return d\n",
52+
"\n",
53+
"df = make_data(500)\n",
54+
"\n",
55+
"df = df.drop(columns=['y'])\n",
56+
"\n",
57+
"transform = vtreat.BinomialOutcomeTreatment(outcome_target=True)\n",
58+
"\n",
59+
"clf = Pipeline(steps=[\n",
60+
" ('preprocessor', transform),\n",
61+
" ('classifier', LogisticRegression(solver = 'lbfgs'))]\n",
62+
")\n",
63+
"\n",
64+
"X, y = df, df.pop('yc')\n",
65+
"\n",
66+
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)\n",
67+
"\n",
68+
"clf.fit(X_train, y_train)\n",
69+
"\n",
70+
"print(\"model score: %.3f\" % clf.score(X_test, y_test))"
71+
]
72+
},
73+
{
74+
"cell_type": "code",
75+
"execution_count": 2,
76+
"metadata": {
77+
"collapsed": false,
78+
"jupyter": {
79+
"outputs_hidden": false
80+
},
81+
"pycharm": {
82+
"is_executing": false,
83+
"name": "#%%\n"
84+
}
85+
},
86+
"outputs": [
87+
{
88+
"name": "stderr",
89+
"output_type": "stream",
90+
"text": [
91+
"/Users/johnmount/opt/anaconda3/envs/ai_academy_3_7/lib/python3.7/site-packages/vtreat/vtreat_api.py:369: UserWarning: called transform on same data used to fit (this causes over-fit, please use fit_transform() instead)\n",
92+
" \"called transform on same data used to fit (this causes over-fit, please use fit_transform() instead)\")\n"
93+
]
94+
},
95+
{
96+
"data": {
97+
"text/plain": [
98+
"0.93"
99+
]
100+
},
101+
"execution_count": 2,
102+
"metadata": {},
103+
"output_type": "execute_result"
104+
}
105+
],
106+
"source": [
107+
"clf.score(X_train, y_train)"
108+
]
109+
},
110+
{
111+
"cell_type": "markdown",
112+
"metadata": {},
113+
"source": [
114+
"The above fit is an over-fit (not achievable without data leakage). Notice vtreat gave as a warning."
115+
]
116+
},
117+
{
118+
"cell_type": "code",
119+
"execution_count": 3,
120+
"metadata": {
121+
"collapsed": false,
122+
"jupyter": {
123+
"outputs_hidden": false
124+
},
125+
"pycharm": {
126+
"is_executing": false,
127+
"name": "#%%\n"
128+
}
129+
},
130+
"outputs": [
131+
{
132+
"name": "stdout",
133+
"output_type": "stream",
134+
"text": [
135+
"Pipeline(memory=None,\n",
136+
" steps=[('preprocessor',\n",
137+
" vtreat.vtreat_api.BinomialOutcomeTreatment(outcome_target=True,\n",
138+
"params={'coders': {'clean_copy',\n",
139+
" 'deviation_code',\n",
140+
" 'impact_code',\n",
141+
" 'indicator_code',\n",
142+
" 'logit_code',\n",
143+
" 'missing_indicator',\n",
144+
" 'prevalence_code'},\n",
145+
" 'cross_validation_k': 5,\n",
146+
" 'cross_validation_plan': <vtreat.cross_plan.KWayCrossPlanYStratified object at 0x10fa81b50>,\n",
147+
" '...\n",
148+
" 'missingness_imputation': <function mean at 0x11093bb90>,\n",
149+
" 'sparse_indicators': True,\n",
150+
" 'use_hierarchical_estimate': True,\n",
151+
" 'user_transforms': []},\n",
152+
")),\n",
153+
" ('classifier',\n",
154+
" LogisticRegression(C=1.0, class_weight=None, dual=False,\n",
155+
" fit_intercept=True, intercept_scaling=1,\n",
156+
" l1_ratio=None, max_iter=100,\n",
157+
" multi_class='warn', n_jobs=None,\n",
158+
" penalty='l2', random_state=None,\n",
159+
" solver='lbfgs', tol=0.0001, verbose=0,\n",
160+
" warm_start=False))],\n",
161+
" verbose=False)\n"
162+
]
163+
}
164+
],
165+
"source": [
166+
"print(clf)"
167+
]
168+
},
169+
{
170+
"cell_type": "code",
171+
"execution_count": 4,
172+
"metadata": {
173+
"collapsed": false,
174+
"jupyter": {
175+
"outputs_hidden": false
176+
},
177+
"pycharm": {
178+
"is_executing": false,
179+
"name": "#%%\n"
180+
}
181+
},
182+
"outputs": [
183+
{
184+
"name": "stdout",
185+
"output_type": "stream",
186+
"text": [
187+
"['x_is_bad', 'xc_is_bad', 'x', 'x2', 'xc_logit_code', 'xc_prevalence_code', 'xc_lev_level_1_0', 'xc_lev__NA_', 'xc_lev_level_-0_5', 'xc_lev_level_0_5']\n"
188+
]
189+
}
190+
],
191+
"source": [
192+
"print(transform.get_feature_names())"
193+
]
194+
},
195+
{
196+
"cell_type": "code",
197+
"execution_count": 5,
198+
"metadata": {
199+
"collapsed": false,
200+
"jupyter": {
201+
"outputs_hidden": false
202+
},
203+
"pycharm": {
204+
"is_executing": false,
205+
"name": "#%%\n"
206+
}
207+
},
208+
"outputs": [
209+
{
210+
"name": "stdout",
211+
"output_type": "stream",
212+
"text": [
213+
"{'use_hierarchical_estimate': True, 'coders': {'prevalence_code', 'logit_code', 'indicator_code', 'deviation_code', 'impact_code', 'missing_indicator', 'clean_copy'}, 'filter_to_recommended': True, 'indicator_min_fraction': 0.1, 'cross_validation_plan': <vtreat.cross_plan.KWayCrossPlanYStratified object at 0x10fa81b50>, 'cross_validation_k': 5, 'user_transforms': [], 'sparse_indicators': True, 'missingness_imputation': <function mean at 0x11093bb90>, 'outcome_target': True}\n"
214+
]
215+
}
216+
],
217+
"source": [
218+
"print(transform.get_params())\n"
219+
]
220+
},
221+
{
222+
"cell_type": "code",
223+
"execution_count": 6,
224+
"metadata": {
225+
"collapsed": false,
226+
"jupyter": {
227+
"outputs_hidden": false
228+
},
229+
"pycharm": {
230+
"is_executing": false,
231+
"name": "#%%\n"
232+
}
233+
},
234+
"outputs": [
235+
{
236+
"name": "stdout",
237+
"output_type": "stream",
238+
"text": [
239+
"{'memory': None, 'steps': [('preprocessor', vtreat.vtreat_api.BinomialOutcomeTreatment(outcome_target=True,\n",
240+
"params={'coders': {'clean_copy',\n",
241+
" 'deviation_code',\n",
242+
" 'impact_code',\n",
243+
" 'indicator_code',\n",
244+
" 'logit_code',\n",
245+
" 'missing_indicator',\n",
246+
" 'prevalence_code'},\n",
247+
" 'cross_validation_k': 5,\n",
248+
" 'cross_validation_plan': <vtreat.cross_plan.KWayCrossPlanYStratified object at 0x10fa81b50>,\n",
249+
" 'filter_to_recommended': True,\n",
250+
" 'indicator_min_fraction': 0.1,\n",
251+
" 'missingness_imputation': <function mean at 0x11093bb90>,\n",
252+
" 'sparse_indicators': True,\n",
253+
" 'use_hierarchical_estimate': True,\n",
254+
" 'user_transforms': []},\n",
255+
")), ('classifier', LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n",
256+
" intercept_scaling=1, l1_ratio=None, max_iter=100,\n",
257+
" multi_class='warn', n_jobs=None, penalty='l2',\n",
258+
" random_state=None, solver='lbfgs', tol=0.0001, verbose=0,\n",
259+
" warm_start=False))], 'verbose': False, 'preprocessor': vtreat.vtreat_api.BinomialOutcomeTreatment(outcome_target=True,\n",
260+
"params={'coders': {'clean_copy',\n",
261+
" 'deviation_code',\n",
262+
" 'impact_code',\n",
263+
" 'indicator_code',\n",
264+
" 'logit_code',\n",
265+
" 'missing_indicator',\n",
266+
" 'prevalence_code'},\n",
267+
" 'cross_validation_k': 5,\n",
268+
" 'cross_validation_plan': <vtreat.cross_plan.KWayCrossPlanYStratified object at 0x10fa81b50>,\n",
269+
" 'filter_to_recommended': True,\n",
270+
" 'indicator_min_fraction': 0.1,\n",
271+
" 'missingness_imputation': <function mean at 0x11093bb90>,\n",
272+
" 'sparse_indicators': True,\n",
273+
" 'use_hierarchical_estimate': True,\n",
274+
" 'user_transforms': []},\n",
275+
"), 'classifier': LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n",
276+
" intercept_scaling=1, l1_ratio=None, max_iter=100,\n",
277+
" multi_class='warn', n_jobs=None, penalty='l2',\n",
278+
" random_state=None, solver='lbfgs', tol=0.0001, verbose=0,\n",
279+
" warm_start=False), 'preprocessor__use_hierarchical_estimate': True, 'preprocessor__coders': {'prevalence_code', 'logit_code', 'indicator_code', 'deviation_code', 'impact_code', 'missing_indicator', 'clean_copy'}, 'preprocessor__filter_to_recommended': True, 'preprocessor__indicator_min_fraction': 0.1, 'preprocessor__cross_validation_plan': <vtreat.cross_plan.KWayCrossPlanYStratified object at 0x10fa81b50>, 'preprocessor__cross_validation_k': 5, 'preprocessor__user_transforms': [], 'preprocessor__sparse_indicators': True, 'preprocessor__missingness_imputation': <function mean at 0x11093bb90>, 'preprocessor__outcome_target': True, 'classifier__C': 1.0, 'classifier__class_weight': None, 'classifier__dual': False, 'classifier__fit_intercept': True, 'classifier__intercept_scaling': 1, 'classifier__l1_ratio': None, 'classifier__max_iter': 100, 'classifier__multi_class': 'warn', 'classifier__n_jobs': None, 'classifier__penalty': 'l2', 'classifier__random_state': None, 'classifier__solver': 'lbfgs', 'classifier__tol': 0.0001, 'classifier__verbose': 0, 'classifier__warm_start': False}\n"
280+
]
281+
}
282+
],
283+
"source": [
284+
"print(clf.get_params())\n"
285+
]
286+
}
287+
],
288+
"metadata": {
289+
"kernelspec": {
290+
"display_name": "Python 3",
291+
"language": "python",
292+
"name": "python3"
293+
},
294+
"language_info": {
295+
"codemirror_mode": {
296+
"name": "ipython",
297+
"version": 3
298+
},
299+
"file_extension": ".py",
300+
"mimetype": "text/x-python",
301+
"name": "python",
302+
"nbconvert_exporter": "python",
303+
"pygments_lexer": "ipython3",
304+
"version": "3.7.5"
305+
},
306+
"pycharm": {
307+
"stem_cell": {
308+
"cell_type": "raw",
309+
"metadata": {
310+
"collapsed": false
311+
},
312+
"source": []
313+
}
314+
}
315+
},
316+
"nbformat": 4,
317+
"nbformat_minor": 4
318+
}

0 commit comments

Comments
 (0)