-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c63493c
commit a0b1896
Showing
10 changed files
with
87 additions
and
123 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,68 +1,40 @@ | ||
import os | ||
import logging | ||
import pandas as pd | ||
from pathlib import Path | ||
from sklearn.preprocessing import StandardScaler | ||
from dotenv import find_dotenv, load_dotenv | ||
|
||
DATA_LINK = 'http://biostat.mc.vanderbilt.edu/wiki/pub/Main/DataSets/titanic3.csv' | ||
TITLES = ['Mlle', 'Mrs', 'Mr', 'Miss', 'Master', 'Don', 'Rev', 'Dr', 'Mme', 'Ms', 'Major', 'Col', 'Capt', 'Countess'] | ||
ROOT = Path(__file__).resolve().parents[2] | ||
|
||
|
||
def extract_title(name, titles=None):
    """Return the first known title that occurs in ``name`` as a whole word.

    Candidates are tried in list order and the first hit wins. Matching is
    done on word boundaries so that a title embedded in a longer word (for
    example ``Don`` inside ``Donaldson``) is not mistaken for the title.

    Args:
        name: Passenger name string to inspect.
        titles: Optional iterable of candidate titles. Defaults to the
            module-level ``TITLES`` list.

    Returns:
        The matched title, or ``'Mr'`` when no candidate is found.
    """
    import re  # local import: keeps the function self-contained

    candidates = TITLES if titles is None else titles
    for item in candidates:
        # re.escape guards against titles containing regex metacharacters.
        if re.search(rf'\b{re.escape(item)}\b', name):
            return item
    # Fallback used by the original implementation when nothing matches.
    return 'Mr'
|
||
|
||
def massage_data(raw_data):
    """Preprocess the transfusion data for predictions.

    Renames the target column to ``label`` and derives whole-year and
    whole-quarter versions of the two month-count columns. The frame is
    modified in place and also returned.

    Args:
        raw_data: DataFrame containing the columns
            ``'whether he/she donated blood in March 2007'``,
            ``'Time (months)'`` and ``'Recency (months)'``.

    Returns:
        The same DataFrame with the target renamed and the columns
        ``time_years``, ``recency_years``, ``time_quarters`` and
        ``recency_quarters`` appended, in that order.
    """
    # index=str is kept from the original call: it maps the index labels
    # through str() in addition to renaming the target column.
    raw_data.rename(
        index=str,
        columns={"whether he/she donated blood in March 2007": "label"},
        inplace=True,
    )

    month_columns = (
        ('time', 'Time (months)'),
        ('recency', 'Recency (months)'),
    )
    # Years first, then quarters (3 month periods), so the derived columns
    # are appended in the same order as before.
    for unit, months_per_unit in (('years', 12), ('quarters', 3)):
        for prefix, source in month_columns:
            # astype('int') truncates the fractional part of the quotient.
            raw_data[f'{prefix}_{unit}'] = (
                raw_data[source] / months_per_unit
            ).astype('int')

    return raw_data
|
||
|
||
def dump_data(data, out_loc):
    """Write ``data`` to a CSV file located at ``out_loc`` under ROOT.

    ``out_loc`` is joined onto the project ROOT with ``os.path.join`` and
    the frame is written there without its index column.
    """
    data.to_csv(os.path.join(ROOT, out_loc), index=False)
|
||
|
||
def main():
    """Read the raw transfusion data, preprocess it and save the result.

    Loads ``data/raw/transfusion_data_raw.csv`` (relative to ROOT), passes
    it through ``massage_data`` and writes the processed frame to
    ``data/processed/transfusion_data.csv`` without the index column.
    """
    df = pd.read_csv(ROOT / 'data/raw/transfusion_data_raw.csv')
    processed_data = massage_data(df)
    processed_data.to_csv(ROOT / 'data/processed/transfusion_data.csv', index=False)
|
||
|
||
if __name__ == '__main__':
    # Configure root logging before any work is done.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )

    # Load environment variables from the .env file located by find_dotenv().
    dotenv_path = find_dotenv()
    load_dotenv(dotenv_path)

    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,38 @@ | ||
import os | ||
import pickle | ||
import logging | ||
import pandas as pd | ||
from pathlib import Path | ||
from sklearn.metrics import roc_auc_score | ||
from train_model import ROOT | ||
|
||
ROOT = Path(__file__).resolve().parents[2] | ||
|
||
|
||
def retrieve_model():
    """Retrieve the pickled model object.

    Returns:
        The object unpickled from ``models/transfusion.model`` under ROOT.
    """
    pickled_model = ROOT / 'models/transfusion.model'
    # NOTE: pickle.load can execute arbitrary code embedded in the file;
    # only load model files that come from a trusted source.
    with open(pickled_model, 'rb') as fin:
        return pickle.load(fin)
|
||
|
||
def main():
    """Retrieve the model and predict labels. Show prediction and performance.

    Reads the processed transfusion test features and labels from
    ``data/processed`` under ROOT, prints the predictions and the ROC AUC.

    Returns:
        A ``(y_pred, auc)`` tuple: the model's predictions for the test
        features and the area under the ROC curve.
    """
    deserialized_model = retrieve_model()
    X_test = pd.read_csv(ROOT / 'data/processed/transfusion_x_test.csv')
    y_pred = deserialized_model.predict(X_test)

    print(f'The model returned these predictions:\n{y_pred}')

    # The label file is read with header=None, i.e. its first row is data.
    y_test = pd.read_csv(ROOT / 'data/processed/transfusion_y_test.csv',
                         header=None)
    # AUC is scored on column 1 of predict_proba
    # (presumably the positive class - confirm against the trained model).
    auc = roc_auc_score(y_test.astype(int), deserialized_model.predict_proba(X_test)[:, 1])
    print('AUC (area under ROC curve): ' + str(auc))
    return y_pred, auc
|
||
|
||
if __name__ == '__main__':
    # Configure logging first so the single main() run below is logged.
    log_fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(level=logging.INFO, format=log_fmt)
    logger = logging.getLogger(__file__)

    # Run the prediction once and report both results.
    preds, auc = main()
    logging.info('The predictions are {}'.format(preds))
    logging.info('The AUC is {}'.format(auc))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,34 +1,17 @@ | ||
import os | ||
import pytest | ||
import random | ||
import pandas as pd | ||
from pathlib import Path | ||
from src.data.make_dataset import extract_title, dump_data | ||
from src.data.make_dataset import massage_data | ||
|
||
ROOT = Path(__file__).resolve().parents[2] | ||
|
||
mock_data = { | ||
'whether he/she donated blood in March 2007': [1, 0, 0, 1], | ||
'Time (months)': [36, 10, 12, 16], | ||
'Recency (months)': [10, 20, 15, 22] | ||
} | ||
|
||
def test_extract_title():
    """Check extract_title against every known name/title pair.

    All cases are asserted on every run so the test is deterministic;
    a randomly chosen single case could hide a regression.
    """
    expected_maps = {
        'Mr Bob': 'Mr',
        'Mrs Daisy': 'Mrs',
        'Sam': 'Mr',  # no title present: falls back to 'Mr'
    }
    for name, expected in expected_maps.items():
        assert extract_title(name) == expected
|
||
|
||
@pytest.mark.usefixtures('tmp_dump_dir')
def test_dump_data(tmp_dump_dir, monkeypatch):
    """Round-trip a small frame through dump_data and read it back."""
    # Make every os.path.join call return the fixture path so that
    # dump_data writes there instead of under the project ROOT.
    # (tmp_dump_dir is presumably a temporary file path - see its fixture.)
    monkeypatch.setattr(os.path, 'join', lambda *paths: tmp_dump_dir)

    frame = pd.DataFrame({'A': [1, 2], 'B': [3, 4]})
    dump_data(frame, 'foo')

    reloaded = pd.read_csv(tmp_dump_dir)
    assert frame.equals(reloaded)
|
||
|
||
|
||
def test_massage_data():
    """Spot-check two cells of the frame returned by massage_data."""
    processed = massage_data(pd.DataFrame(mock_data))
    # Row 0, column 2: the 'Recency (months)' value from mock_data.
    assert processed.iloc[0, 2] == 10
    # Row 3, column 6: 22 recency months expressed in whole quarters.
    assert processed.iloc[3, 6] == 7
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters