
Commit 9add1e7

Amrit Baveja authored and committed

added sc and sqlctx to show confusion matrix

1 parent d741fba

File tree

9 files changed: +18 additions, -165 deletions


.idea/misc.xml (deleted: 0 additions, 4 deletions)
.idea/modules.xml (deleted: 0 additions, 8 deletions)
.idea/new-pysplice.iml (deleted: 0 additions, 12 deletions)
.idea/vcs.xml (deleted: 0 additions, 6 deletions)
.idea/workspace.xml (deleted: 0 additions, 128 deletions)
Binary file (148 Bytes): not shown
Binary file (151 Bytes): not shown
Binary file (10.6 KB): not shown

splicemachine/ml/zeppelin.py

Lines changed: 18 additions & 7 deletions
@@ -88,7 +88,7 @@ def show_confusion_matrix(sc, sqlContext, TP, TN, FP, FN):
 def experiment_maker(experiment_id):
     """
     a function that creates a new experiment if "experiment_name" doesn't exist
-    or will use the current one if it already does
+    or will use the current one if it already does
     :param experiment_id the experiment name you would like to get or create
     """
     print("Tracking Path " + mlflow.get_tracking_uri())
@@ -115,7 +115,7 @@ class ModelEvaluator(object):
     A Function that provides an easy way to evaluate models once, or over random iterations
     """
 
-    def __init__(self, sqlContext, label_column='label', prediction_column='prediction', confusion_matrix=True):
+    def __init__(self, label_column='label', prediction_column='prediction', confusion_matrix=True):
         """
         :param sc: Spark Context
         :param sqlContext: SQLContext
@@ -127,11 +127,21 @@ def __init__(self, sqlContext, label_column='label', prediction_column='predicti
         self.avg_tn = []
         self.avg_fn = []
         self.avg_fp = []
-        self.sqlContext = sqlContext
+        self.sqlContext = None
+        self.sc = None
         self.label_column = label_column
         self.prediction_column = prediction_column
         self.confusion_matrix = confusion_matrix
 
+    def setup_contexts(self, sc, sqlContext):
+        """
+        Setup contexts for ModelEvaluator
+        :param sc: spark context
+        :param sqlContext: sql context
+        """
+        self.sc = sc
+        self.sqlContext = sqlContext
+
     def input(self, predictions_dataframe):
         """
         Evaluate actual vs Predicted in a dataframe
@@ -151,14 +161,15 @@ def input(self, predictions_dataframe):
             pred_v_lab[(pred_v_lab.label == 0) & (pred_v_lab.prediction == 1)].count())
 
         if self.confusion_matrix:
-            show_confusion_matrix(self.avg_tp[-1], self.avg_tn[-1], self.avg_fp[-1],
-                                  self.avg_fn[-1])  # show the confusion matrix to the user
+            show_confusion_matrix(self.sc, self.sqlContext, self.avg_tp[-1],
+                                  self.avg_tn[-1], self.avg_fp[-1], self.avg_fn[-1])
+            # show the confusion matrix to the user
 
     def get_results(self, output_type='dataframe'):
         """
         Return a dictionary containing evaluated results
-        :param output_type: either a dataframe or a dict (which to return)
-        :return results: computed_metrics (dict) or computed_df (df)
+        :param output_type: either a dataframe or a dict (which to return)
+        :return results: computed_metrics (dict) or computed_df (df)
         """
 
         TP = np.mean(self.avg_tp)
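
After this commit, ModelEvaluator no longer receives sqlContext through its constructor; callers are expected to wire the Spark and SQL contexts in afterwards via setup_contexts, which input() then forwards to show_confusion_matrix. A minimal usage sketch of the new flow follows; the names predictions_df, sc, and sqlContext are illustrative (a Spark DataFrame with 'label' and 'prediction' columns and the notebook's existing contexts), and the import path is assumed from the file location shown above, not confirmed by this commit.

# Usage sketch of the two-step initialization introduced here.
# Assumptions: sc / sqlContext are the notebook's SparkContext and SQLContext;
# predictions_df is a Spark DataFrame with 'label' and 'prediction' columns.
from splicemachine.ml.zeppelin import ModelEvaluator

evaluator = ModelEvaluator(label_column='label',
                           prediction_column='prediction',
                           confusion_matrix=True)
evaluator.setup_contexts(sc, sqlContext)  # the new step: contexts are no longer constructor args
evaluator.input(predictions_df)           # tallies TP/TN/FP/FN and, since confusion_matrix=True,
                                          # calls show_confusion_matrix(sc, sqlContext, ...)
results = evaluator.get_results()         # defaults to a dataframe of the computed metrics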
