@@ -39,7 +39,8 @@ def handle_handlers(handler, *args, **kwargs):
3939 def log_metadata (self , handler , * args , ** kwargs ):
4040 if not self .run_uuid :
4141 with mlflow .start_run ():
42- self .run_uuid = (mlflow .active_run ().__dict__ ['_info' ].__dict__ ['_run_uuid' ])
42+ self .run_uuid = (mlflow .active_run ().__dict__ [
43+ '_info' ].__dict__ ['_run_uuid' ])
4344 print ("Logged using handler " + handler )
4445 Run .handle_handlers (handler , * args , ** kwargs )
4546 else :
@@ -77,7 +78,8 @@ def show_confusion_matrix(TP, TN, FP, FN):
7778 :param FN: False Negatives
7879 """
7980 confusion_rdd = sc .parallelize ([['Predicted' , TP , FN ], ['Actual' , FP , TN ]])
80- confusion_matrix = sqlContext .createDataFrame (confusion_rdd , ['' , 'Actual' , 'Predicted' ])
81+ confusion_matrix = sqlContext .createDataFrame (
82+ confusion_rdd , ['' , 'Actual' , 'Predicted' ])
8183 confusion_matrix .show ()
8284
8385
@@ -98,7 +100,8 @@ def experiment_maker(experiment_id):
98100 e ._experiment_id ) # use already created experiment
99101
100102 if not found :
101- _id = mlflow .tracking .create_experiment (experiment_id ) # create new experiment
103+ _id = mlflow .tracking .create_experiment (
104+ experiment_id ) # create new experiment
102105 print ('Success! Created Experiment' )
103106 os .environ ['MLFLOW_EXPERIMENT_ID' ] = str (_id ) # use it
104107 else :
@@ -135,7 +138,7 @@ def input(self, predictions_dataframe):
135138 self .prediction_column ) # Select the actual and the predicted labels
136139
137140 self .avg_tp .append (pred_v_lab [(pred_v_lab .label == 1 ) & (
138- pred_v_lab .prediction == 1 )].count ()) # Add confusion stats
141+ pred_v_lab .prediction == 1 )].count ()) # Add confusion stats
139142 self .avg_tn .append (
140143 pred_v_lab [(pred_v_lab .label == 0 ) & (pred_v_lab .prediction == 0 )].count ())
141144 self .avg_fp .append (
@@ -177,13 +180,14 @@ def get_results(self, output_type='dataframe'):
177180 return computed_metrics
178181
179182 else :
180- metrics_row = Row ('TPR' , 'SPC' , 'PPV' , 'NPV' , 'FPR' , 'FDR' , 'FNR' , 'ACC' , 'F1' , 'MCC' )
183+ metrics_row = Row ('TPR' , 'SPC' , 'PPV' , 'NPV' ,
184+ 'FPR' , 'FDR' , 'FNR' , 'ACC' , 'F1' , 'MCC' )
181185 computed_row = metrics_row (* computed_metrics .values ())
182186 computed_df = sqlContext .createDataFrame ([computed_row ])
183187 return computed_df
184188
185189
186- def print_horizantal_line (l ):
190+ def print_horizontal_line (l ):
187191 print ("" .join (['-' * l ]))
188192
189193
@@ -235,8 +239,9 @@ def visualize(model, feature_column_names, label_names, tree_name, visual=True):
235239 time .sleep (3 )
236240
237241 print ('You can find your visualization at "https://docs.google.com/gview?url=https'
238- '://<cluster_name>.splicemachine.io/assets/images/' + tree_name + '.pdf&embedded=tru'
239- 'e#view=fith' )
242+ '://<cluster_name>.splicemachine.io/assets/images/' +
243+ tree_name + '.pdf&embedded=tru'
244+ 'e#view=fith' )
240245
241246 @staticmethod
242247 def replacer (string , bad , good ):
@@ -269,10 +274,13 @@ def add_node(dot, parent, node_hash, root, realroot=False):
269274 dot .edge (node_hash , node_id )
270275 if root .get ('children' ):
271276 if not root ['children' ][0 ].get ('children' ):
272- DecisionTreeVisualizer .add_node (dot , root ['name' ], node_id , root ['children' ][0 ])
277+ DecisionTreeVisualizer .add_node (
278+ dot , root ['name' ], node_id , root ['children' ][0 ])
273279 else :
274- DecisionTreeVisualizer .add_node (dot , root ['name' ], node_id , root ['children' ][0 ])
275- DecisionTreeVisualizer .add_node (dot , root ['name' ], node_id , root ['children' ][1 ])
280+ DecisionTreeVisualizer .add_node (
281+ dot , root ['name' ], node_id , root ['children' ][0 ])
282+ DecisionTreeVisualizer .add_node (
283+ dot , root ['name' ], node_id , root ['children' ][1 ])
276284
277285 @staticmethod
278286 def parse (lines ):
@@ -285,12 +293,16 @@ def parse(lines):
285293 while lines :
286294
287295 if lines [0 ].startswith ('If' ):
288- bl = ' ' .join (lines .pop (0 ).split ()[1 :]).replace ('(' , '' ).replace (')' , '' )
289- block .append ({'name' : bl , 'children' : DecisionTreeVisualizer .parse (lines )})
296+ bl = ' ' .join (lines .pop (0 ).split ()[1 :]).replace (
297+ '(' , '' ).replace (')' , '' )
298+ block .append (
299+ {'name' : bl , 'children' : DecisionTreeVisualizer .parse (lines )})
290300
291301 if lines [0 ].startswith ('Else' ):
292- be = ' ' .join (lines .pop (0 ).split ()[1 :]).replace ('(' , '' ).replace (')' , '' )
293- block .append ({'name' : be , 'children' : DecisionTreeVisualizer .parse (lines )})
302+ be = ' ' .join (lines .pop (0 ).split ()[1 :]).replace (
303+ '(' , '' ).replace (')' , '' )
304+ block .append (
305+ {'name' : be , 'children' : DecisionTreeVisualizer .parse (lines )})
294306 elif not lines [0 ].startswith (('If' , 'Else' )):
295307 block2 = lines .pop (0 )
296308 block .append ({'name' : block2 })
@@ -314,5 +326,6 @@ def tree_json(tree):
314326 break
315327 if not line :
316328 break
317- res = [{'name' : 'Root' , 'children' : DecisionTreeVisualizer .parse (data [1 :])}]
329+ res = [
330+ {'name' : 'Root' , 'children' : DecisionTreeVisualizer .parse (data [1 :])}]
318331 return res [0 ]
0 commit comments