Skip to content

Commit c64a3d2

Browse files
committed
refactor decision tree model
1 parent 74e0709 commit c64a3d2

File tree

5 files changed

+21856
-21606
lines changed

5 files changed

+21856
-21606
lines changed

decision-tree.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# Encoding: utf-8
2+
"""
3+
written by: Lawrence McDaniel
4+
https://lawrencemcdaniel.com
5+
6+
date: jun-2023
7+
8+
usage: minimalist implementation of a Decision Tree classifier model.
9+
"""
10+
import os
11+
import warnings
12+
13+
# ------------------------------------------------------------------------------
14+
# IMPORTANT: DON'T FORGET TO INSTALL THESE LIBRARIES WITH pip
15+
# ------------------------------------------------------------------------------
16+
# Code to ignore warnings from function usage
17+
import pandas as pd
18+
import numpy as np
19+
import matplotlib.pyplot as plt
20+
import seaborn as sns
21+
22+
# Importing the Machine Learning models we require from Scikit-Learn
23+
from sklearn.tree import DecisionTreeClassifier
24+
from sklearn import tree
25+
from sklearn import metrics
26+
27+
from sklearn.model_selection import train_test_split, GridSearchCV
28+
from sklearn.metrics import confusion_matrix, classification_report, recall_score
29+
30+
# module-level setup, executed once on import
sns.set()  # switch plots to seaborn's default theme
HERE = os.path.dirname(os.path.abspath(__file__))  # absolute directory of this file
warnings.simplefilter("ignore")  # suppress every warning category
34+
35+
36+
def metrics_score(actual, predicted):
    """
    Report how well a set of predictions matches the true labels.

    Prints the classification report for (actual, predicted) and then
    displays the confusion matrix as an annotated heatmap. Shared by the
    training-set and test-set evaluations so both are reported the same way.
    """
    print("Metrics Score.")
    report = classification_report(actual, predicted)
    print(report)

    matrix = confusion_matrix(actual, predicted)
    class_labels = ["Not Cancelled", "Cancelled"]

    # Same tick labels on both axes: rows are the actual class, columns
    # the predicted class.
    heatmap_options = {
        "annot": True,
        "fmt": ".2f",
        "xticklabels": class_labels,
        "yticklabels": class_labels,
    }

    plt.figure(figsize=(8, 5))
    sns.heatmap(matrix, **heatmap_options)
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.show()
57+
58+
59+
def prepare_data():
    """
    Load the reservations data and return model-ready train / test splits.

    Steps:
    - read data/reservations-db.csv from the directory of this module
    - work on a copy and discard the Booking_ID identifier column
    - encode booking_status as 1 for "Canceled" and 0 for any other value
    - expand categorical features into dummy columns (first level dropped)
    - make a stratified 70/30 train / test split with a fixed random seed

    Returns the result of train_test_split():
    x_train, x_test, y_train, y_test.
    """
    print("Preparing data sets")
    csv_path = os.path.join(HERE, "data", "reservations-db.csv")
    raw = pd.read_csv(csv_path)

    # Operate on a copy so the frame as loaded is never modified. The ID
    # column is dropped because it carries no predictive information.
    frame = raw.copy().drop(columns=["Booking_ID"])

    # Dependent variable: 1 when the booking was canceled, otherwise 0.
    frame["booking_status"] = frame["booking_status"].eq("Canceled").astype("int64")

    # Separate the dependent variable from the features, then encode every
    # categorical feature as dummy columns.
    target = frame["booking_status"]
    features = pd.get_dummies(
        frame.drop(columns=["booking_status"]),
        drop_first=True,
    )

    # Stratify on the target so both splits keep the same class balance.
    return train_test_split(
        features,
        target,
        test_size=0.30,
        stratify=target,
        random_state=1,
    )
95+
96+
97+
def decision_tree():
    """
    Train, evaluate, tune and visualize a Decision Tree classifier.

    - create training and test data sets
    - create a Decision Tree model and train it
    - print the classification report and show the confusion matrix
      for the training set and for the test set
    - tune max_depth, max_leaf_nodes and min_samples_split with a 10-fold
      cross-validated grid search scored on recall of class 1
    - evaluate the tuned model on the training set and on the test set
    - plot the top levels of the tuned tree

    Takes no arguments and returns None; all results are printed or plotted.
    """
    print("Decision Tree")
    x_train, x_test, y_train, y_test = prepare_data()

    # Class 1 is weighted more heavily than class 0. The baseline and the
    # tuned model share the same weights and seed so they are comparable.
    class_weight = {0: 0.17, 1: 0.83}

    print("- training")
    model_dt = DecisionTreeClassifier(class_weight=class_weight, random_state=1)
    model_dt.fit(x_train, y_train)

    print("- modeling on training data")
    pred_train_dt = model_dt.predict(x_train)
    metrics_score(y_train, pred_train_dt)

    print("- modeling on test data")
    pred_test_dt = model_dt.predict(x_test)
    metrics_score(y_test, pred_test_dt)

    # Hyperparameter tuning: same classifier type, searched over a small grid.
    estimator = DecisionTreeClassifier(class_weight=class_weight, random_state=1)
    parameters = {
        "max_depth": np.arange(2, 7, 2),
        "max_leaf_nodes": [50, 75, 150, 250],
        "min_samples_split": [10, 30, 50, 70],
    }

    # Select the parameter combination by recall on the positive class (1).
    scorer = metrics.make_scorer(recall_score, pos_label=1)

    # Run the grid search. With the default refit=True, GridSearchCV refits
    # the best parameter combination on the whole training set, so
    # best_estimator_ is already fitted and needs no further fit() call.
    grid_obj = GridSearchCV(estimator, parameters, scoring=scorer, cv=10)
    grid_obj.fit(x_train, y_train)
    estimator = grid_obj.best_estimator_

    # Performance of the tuned model on the training data.
    print("- remodeling on training data")
    dt_tuned = estimator.predict(x_train)
    metrics_score(y_train, dt_tuned)

    # Performance of the tuned model on the held-out test data.
    print("- remodeling on test data")
    y_pred_tuned = estimator.predict(x_test)
    metrics_score(y_test, y_pred_tuned)

    # Visualization of the tuned decision tree (top levels only).
    feature_names = list(x_train.columns)
    plt.figure(figsize=(20, 10))
    out = tree.plot_tree(
        estimator,
        max_depth=3,
        feature_names=feature_names,
        filled=True,
        fontsize=9,
        node_ids=False,
        class_names=None,
    )
    # Make the split arrows visible: some annotations have no arrow patch,
    # the rest get a thin black edge.
    for annotation in out:
        arrow = annotation.arrow_patch
        if arrow is not None:
            arrow.set_edgecolor("black")
            arrow.set_linewidth(1)
    plt.show()
171+
172+
173+
# Run the full train / tune / plot workflow only when this file is executed
# as a script, not when it is imported.
if __name__ == "__main__":
    decision_tree()

0 commit comments

Comments
 (0)