Skip to content
This repository was archived by the owner on Apr 15, 2022. It is now read-only.

Commit d741fba

Browse files
authored
Update zeppelin.py
1 parent b8059de commit d741fba

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

splicemachine/ml/zeppelin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class ModelEvaluator(object):
115115
A Function that provides an easy way to evaluate models once, or over random iterations
116116
"""
117117

118-
def __init__(self, sc, sqlContext, label_column='label', prediction_column='prediction', confusion_matrix=True):
118+
def __init__(self, sqlContext, label_column='label', prediction_column='prediction', confusion_matrix=True):
119119
"""
120120
:param sc: Spark Context
121121
:param sqlContext: SQLContext
@@ -127,7 +127,7 @@ def __init__(self, sc, sqlContext, label_column='label', prediction_column='pred
127127
self.avg_tn = []
128128
self.avg_fn = []
129129
self.avg_fp = []
130-
130+
self.sqlContext = sqlContext
131131
self.label_column = label_column
132132
self.prediction_column = prediction_column
133133
self.confusion_matrix = confusion_matrix
@@ -187,7 +187,7 @@ def get_results(self, output_type='dataframe'):
187187
metrics_row = Row('TPR', 'SPC', 'PPV', 'NPV',
188188
'FPR', 'FDR', 'FNR', 'ACC', 'F1', 'MCC')
189189
computed_row = metrics_row(*computed_metrics.values())
190-
computed_df = sqlContext.createDataFrame([computed_row])
190+
computed_df = self.sqlContext.createDataFrame([computed_row])
191191
return computed_df
192192

193193

0 commit comments

Comments
 (0)