@@ -88,7 +88,7 @@ def show_confusion_matrix(sc, sqlContext, TP, TN, FP, FN):
8888def experiment_maker (experiment_id ):
8989 """
9090 a function that creates a new experiment if "experiment_name" doesn't exist
91- or will use the current one if it already does
91+ or will use the current one if it already does
9292 :param experiment_id the experiment name you would like to get or create
9393 """
9494 print ("Tracking Path " + mlflow .get_tracking_uri ())
@@ -115,7 +115,7 @@ class ModelEvaluator(object):
115115 A Function that provides an easy way to evaluate models once, or over random iterations
116116 """
117117
118- def __init__ (self , sqlContext , label_column = 'label' , prediction_column = 'prediction' , confusion_matrix = True ):
118+ def __init__ (self , label_column = 'label' , prediction_column = 'prediction' , confusion_matrix = True ):
119119 """
120120 :param sc: Spark Context
121121 :param sqlContext: SQLContext
@@ -127,11 +127,21 @@ def __init__(self, sqlContext, label_column='label', prediction_column='predicti
127127 self .avg_tn = []
128128 self .avg_fn = []
129129 self .avg_fp = []
130- self .sqlContext = sqlContext
130+ self .sqlContext = None
131+ self .sc = None
131132 self .label_column = label_column
132133 self .prediction_column = prediction_column
133134 self .confusion_matrix = confusion_matrix
134135
def setup_contexts(self, sc, sqlContext):
    """
    Attach the Spark and SQL contexts to this evaluator.

    Both values are stored on the instance unchanged; nothing is returned.

    :param sc: spark context to keep as ``self.sc``
    :param sqlContext: sql context to keep as ``self.sqlContext``
    """
    # Bind both contexts in a single step; no validation is performed.
    self.sc, self.sqlContext = sc, sqlContext
144+
135145 def input (self , predictions_dataframe ):
136146 """
137147 Evaluate actual vs Predicted in a dataframe
@@ -151,14 +161,15 @@ def input(self, predictions_dataframe):
151161 pred_v_lab [(pred_v_lab .label == 0 ) & (pred_v_lab .prediction == 1 )].count ())
152162
153163 if self .confusion_matrix :
154- show_confusion_matrix (self .avg_tp [- 1 ], self .avg_tn [- 1 ], self .avg_fp [- 1 ],
155- self .avg_fn [- 1 ]) # show the confusion matrix to the user
164+ show_confusion_matrix (self .sc , self .sqlContext , self .avg_tp [- 1 ],
165+ self .avg_tn [- 1 ], self .avg_fp [- 1 ], self .avg_fn [- 1 ])
166+ # show the confusion matrix to the user
156167
157168 def get_results (self , output_type = 'dataframe' ):
158169 """
159170 Return a dictionary containing evaluated results
160- :param output_type: either a dataframe or a dict (which to return)
161- :return results: computed_metrics (dict) or computed_df (df)
171+ :param output_type: either a dataframe or a dict (which to return)
172+ :return results: computed_metrics (dict) or computed_df (df)
162173 """
163174
164175 TP = np .mean (self .avg_tp )
0 commit comments