diff --git a/src/tweet_classifier_BERT.py b/src/tweet_classifier_BERT.py index ede2280..d7bb94c 100644 --- a/src/tweet_classifier_BERT.py +++ b/src/tweet_classifier_BERT.py @@ -212,21 +212,17 @@ def confusion(prediction, truth): - 0 and 0 (True Negative) - 0 and 1 (False Negative) """ + # Getting values prediction = np.argmax(prediction, axis=1).flatten() truth = truth.flatten() - confusion_vector = prediction / truth - # Element-wise division of the 2 arrays returns a new tensor which holds a - # unique value for each case: - # 1 where prediction and truth are 1 (True Positive) - # inf where prediction is 1 and truth is 0 (False Positive) - # nan where prediction and truth are 0 (True Negative) - # 0 where prediction is 0 and truth is 1 (False Negative) - - true_positives = np.sum(confusion_vector == 1) - false_positives = np.sum(confusion_vector == float('inf')) - true_negatives = np.sum(np.isnan(confusion_vector)) - false_negatives = np.sum(confusion_vector == 0) - + # Applying filtering + positives = np.where(prediction == 1) + negatives = np.where(prediction == 0) + # Building the confusion matrix + true_positives = np.size(np.where(truth[positives] == 1)) + false_positives = np.size(np.where(truth[positives] == 0)) + true_negatives = np.size(np.where(truth[negatives] == 0)) + false_negatives = np.size(np.where(truth[negatives] == 1)) return true_positives, false_positives, true_negatives, false_negatives