Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
lpm0073 committed Jun 29, 2023
1 parent daf8417 commit 9916a1e
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions logistic-regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,31 +86,32 @@ def prepare_data():
return train_test_split(x, y, test_size=0.30, stratify=y, random_state=1)


def main():
def logistic_regression():
"""
- create training and test data sets
- create a Logistic Regression model
- train the model
- generate confusion matrix and f-score for the training set
- generate confusion matrix and f-score for the test set
"""
print("Prepare data")
x_train, x_test, y_train, y_test = prepare_data()

# Fit a logistic regression model
print("train model")
model = LogisticRegression()
model.fit(x_train, y_train)

# Set the optimal threshold (refer to the Jupyter Notebook to see how we arrived at 42)
optimal_threshold = 0.42

# Create a confusion matrix for the training data
print("model training data and measure results")
y_pred_train = model.predict_proba(x_train)
metrics_score(y_train, y_pred_train[:, 1] > optimal_threshold)

# Create a confusion matrix for the test data
print("model test data and measure results")
y_pred_test = model.predict_proba(x_test)
metrics_score(y_test, y_pred_test[:, 1] > optimal_threshold)


if __name__ == "__main__":
main()
logistic_regression()

0 comments on commit 9916a1e

Please sign in to comment.