"""
Utilities for doing machine learning from Zeppelin notebooks.

Note: `sc` (SparkContext) and `sqlContext` are assumed to be injected by the
Zeppelin Spark interpreter; they are used below without being defined here.
"""

import os
import random
import time

import graphviz
import mlflow
import mlflow.spark
import numpy as np
from pyspark.sql import Row


class Run:
    """
    An abstraction over MLflow runs, allowing a single run to span multiple
    notebook cells
    """

    def __init__(self):
        self.run_uuid = None

    @staticmethod
    def handle_handlers(handler, *args, **kwargs):
        """Dispatch to the matching MLflow logging function."""
        if handler == 'param':
            mlflow.log_param(*args, **kwargs)
        elif handler == 'metric':
            mlflow.log_metric(*args, **kwargs)
        elif handler == 'artifact':
            mlflow.log_artifact(*args, **kwargs)
        elif handler == 'spark_model':
            mlflow.spark.log_model(*args, **kwargs)
        else:
            raise ValueError(
                "Handler {0} not understood. Please use one of ['param', 'metric', "
                "'artifact', 'spark_model']".format(handler))

    def log_metadata(self, handler, *args, **kwargs):
        if not self.run_uuid:
            # First call: start a fresh MLflow run and remember its UUID so that
            # later calls (possibly from other cells) attach to the same run
            with mlflow.start_run():
                self.run_uuid = mlflow.active_run().info.run_uuid
                print("Logged using handler " + handler)
                Run.handle_handlers(handler, *args, **kwargs)
        else:
            # Subsequent calls: reattach to the previously started run
            with mlflow.start_run(run_uuid=self.run_uuid):
                Run.handle_handlers(handler, *args, **kwargs)
                print("Logged using handler " + handler)
        return True

    def log_param(self, *args, **kwargs):
        return self.log_metadata('param', *args, **kwargs)

    def log_metric(self, *args, **kwargs):
        return self.log_metadata('metric', *args, **kwargs)

    def log_artifact(self, *args, **kwargs):
        return self.log_metadata('artifact', *args, **kwargs)

    def log_model(self, *args, **kwargs):
        return self.log_metadata('spark_model', *args, **kwargs)

    def create_new_run(self):
        """
        Reset the stored run UUID so the next log call starts a new MLflow run
        """
        self.run_uuid = None

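# Example: cross-cell MLflow logging with Run (a sketch; the parameter and
# metric names below are illustrative, not from the original source):
#
#     run = Run()
#     run.log_param("maxDepth", "5")      # e.g. in one Zeppelin cell
#     run.log_metric("auc", 0.87)         # in a later cell, same MLflow run
#     run.create_new_run()                # further logging starts a new run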

def show_confusion_matrix(TP, TN, FP, FN):
    """
    Display a confusion matrix so you can see how well your model performs
    :param TP: True Positives
    :param TN: True Negatives
    :param FP: False Positives
    :param FN: False Negatives
    """
    # Rows are the actual classes, columns are the predicted classes
    confusion_rdd = sc.parallelize([['Actual Positive', TP, FN],
                                    ['Actual Negative', FP, TN]])
    confusion_matrix = sqlContext.createDataFrame(
        confusion_rdd, ['', 'Predicted Positive', 'Predicted Negative'])
    confusion_matrix.show()

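# For example, show_confusion_matrix(85, 90, 10, 15) prints a 2x2 table with 85
# true positives, 90 true negatives, 10 false positives and 15 false negatives
# (illustrative counts only).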

def experiment_maker(experiment_id):
    """
    Create a new experiment if one named `experiment_id` doesn't exist yet,
    or reuse the existing experiment if it already does
    :param experiment_id: the experiment name you would like to get or create
    """
    print("Tracking Path " + mlflow.get_tracking_uri())
    found = False
    if len(experiment_id) not in (0, 1):
        for e in mlflow.tracking.list_experiments():  # check all experiments
            if experiment_id == e.name:
                print('Experiment has already been created')
                found = True
                # Use the already created experiment
                os.environ['MLFLOW_EXPERIMENT_ID'] = str(e.experiment_id)
                break

        if not found:
            _id = mlflow.tracking.create_experiment(experiment_id)  # create new experiment
            print('Success! Created Experiment')
            os.environ['MLFLOW_EXPERIMENT_ID'] = str(_id)  # use it
    else:
        print("Please fill out this field")

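# Example (hypothetical experiment name):
#
#     experiment_maker("fraud_detection_v1")
#     # experiment_maker exports MLFLOW_EXPERIMENT_ID so that later MLflow runs
#     # from this notebook are logged under that experiment.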

class ModelEvaluator(object):
    """
    A class that provides an easy way to evaluate models once, or over repeated iterations
    """

    def __init__(self, label_column='label', prediction_column='prediction', confusion_matrix=True):
        """
        :param label_column: the column in the dataframe containing the correct output
        :param prediction_column: the column in the dataframe containing the prediction
        :param confusion_matrix: whether or not to show a confusion matrix after each input
        """
        # Per-iteration confusion-matrix counts; get_results() averages them
        self.avg_tp = []
        self.avg_tn = []
        self.avg_fn = []
        self.avg_fp = []

        self.label_column = label_column
        self.prediction_column = prediction_column
        self.confusion_matrix = confusion_matrix

    def input(self, predictions_dataframe):
        """
        Evaluate actual vs. predicted values in a dataframe
        :param predictions_dataframe: the dataframe containing the label and the prediction
        """
        # Select the actual and the predicted labels
        pred_v_lab = predictions_dataframe.select(self.label_column,
                                                  self.prediction_column)
        label = pred_v_lab[self.label_column]
        prediction = pred_v_lab[self.prediction_column]

        # Accumulate confusion-matrix counts for this iteration
        self.avg_tp.append(pred_v_lab[(label == 1) & (prediction == 1)].count())
        self.avg_tn.append(pred_v_lab[(label == 0) & (prediction == 0)].count())
        self.avg_fp.append(pred_v_lab[(label == 0) & (prediction == 1)].count())
        self.avg_fn.append(pred_v_lab[(label == 1) & (prediction == 0)].count())

        if self.confusion_matrix:
            # Show the confusion matrix for this iteration
            show_confusion_matrix(self.avg_tp[-1], self.avg_tn[-1],
                                  self.avg_fp[-1], self.avg_fn[-1])

    def get_results(self, output_type='dataframe'):
        """
        Return the computed metrics, averaged over all inputs so far
        :param output_type: either 'dataframe' or 'dict' (which to return)
        :return: computed_metrics (dict) or computed_df (DataFrame)
        """

        TP = np.mean(self.avg_tp)
        TN = np.mean(self.avg_tn)
        FP = np.mean(self.avg_fp)
        FN = np.mean(self.avg_fn)

        computed_metrics = {
            'TPR': float(TP) / (TP + FN),  # sensitivity / recall
            'SPC': float(TN) / (TN + FP),  # specificity
            'PPV': float(TP) / (TP + FP),  # precision
            'NPV': float(TN) / (TN + FN),  # negative predictive value
            'FPR': float(FP) / (FP + TN),  # fall-out
            'FDR': float(FP) / (FP + TP),  # false discovery rate
            'FNR': float(FN) / (FN + TP),  # miss rate
            'ACC': float(TP + TN) / (TP + FN + FP + TN),  # accuracy
            'F1': float(2 * TP) / ((2 * TP) + FP + FN),  # F1 score
            'MCC': float((TP * TN) - (FP * FN)) / (
                np.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN)))  # Matthews corr. coef.
        }

        if output_type == 'dict':
            return computed_metrics

        field_names = ['TPR', 'SPC', 'PPV', 'NPV', 'FPR', 'FDR', 'FNR', 'ACC', 'F1', 'MCC']
        metrics_row = Row(*field_names)
        computed_row = metrics_row(*[computed_metrics[f] for f in field_names])
        computed_df = sqlContext.createDataFrame([computed_row])
        return computed_df

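# Example usage (a sketch; `fitted_model` and `test_df` are placeholders for a
# fitted Spark ML model and a test DataFrame, not names from the original source):
#
#     evaluator = ModelEvaluator()
#     predictions = fitted_model.transform(test_df)
#     evaluator.input(predictions)            # call once per fold / iteration
#     metrics = evaluator.get_results('dict')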

def print_horizantal_line(l):
    """Print a horizontal rule of length `l`."""
    print('-' * l)


def display(html):
    """Render raw HTML in a Zeppelin note via the %angular interpreter."""
    print("%angular")
    print(html)


class DecisionTreeVisualizer(object):
    """
    Visualize a decision tree, either in a code-like format or with graphviz
    """

    @staticmethod
    def visualize(model, feature_column_names, label_names, tree_name, visual=True):
        """
        Visualize a decision tree, either in a code-like format or with graphviz
        :param model: the fitted decision tree classifier
        :param feature_column_names: column names for the features
        :param label_names: the label values (e.g. below avg, above avg)
        :param tree_name: the name you would like to give the tree
        :param visual: bool, True to render a graphviz PDF of the tree
        :return: the substituted debug string if visual is False, otherwise None
        """

        # Replace "feature N" placeholders in the debug string with readable names
        tree_to_json = DecisionTreeVisualizer.replacer(model.toDebugString,
                                                       ['feature ' + str(i) for i in
                                                        range(len(feature_column_names))],
                                                       feature_column_names)
        tree_to_json = DecisionTreeVisualizer.replacer(tree_to_json,
                                                       ['Predict ' + str(i) + '.0' for i in
                                                        range(len(label_names))], label_names)
        if not visual:
            return tree_to_json

        dot = graphviz.Digraph(comment='Decision Tree')
        dot.attr(size="7.75,15.25")
        dot.node_attr.update(color='lightblue2', style='filled')
        json_d = DecisionTreeVisualizer.tree_json(tree_to_json)
        dot.format = 'pdf'

        DecisionTreeVisualizer.add_node(dot, '', '', json_d, realroot=True)
        dot.render('/zeppelin/webapps/webapp/assets/images/' + tree_name)

        print('Uploading.')
        time.sleep(3)
        print('Uploading..')
        time.sleep(3)
        print('Successfully uploaded file to Zeppelin assets on this cluster')

        print('You can find your visualization at "https://docs.google.com/gview?url=https'
              '://<cluster_name>.splicemachine.io/assets/images/' + tree_name +
              '.pdf&embedded=true#view=fit"')

    @staticmethod
    def replacer(string, bad, good):
        """
        Replace every string in "bad" with the corresponding string in "good"
        :param string: string to replace in
        :param bad: array of strings to replace
        :param good: array of strings to replace with
        :return:
        """
        for b, g in zip(bad, good):
            string = string.replace(b, g)
        return string

    @staticmethod
    def add_node(dot, parent, node_hash, root, realroot=False):
        """
        Traverse through the .debugString json and generate a graphviz tree
        :param dot: dot file object
        :param parent: not used currently
        :param node_hash: unique node id of the parent
        :param root: the root of the (sub)tree
        :param realroot: whether this is the real root or a recursive root
        :return:
        """
        if root:
            # Random suffix keeps node ids unique even when names repeat
            node_id = str(hash(root['name'])) + str(random.randint(0, 100))
            dot.node(node_id, root['name'])
            if not realroot:
                dot.edge(node_hash, node_id)
            for child in root.get('children', []):
                DecisionTreeVisualizer.add_node(dot, root['name'], node_id, child)

    @staticmethod
    def parse(lines):
        """
        Parse the (already stripped) lines of a debug string into nested blocks
        :param lines: list of lines from the debug string
        :return: block json
        """
        block = []
        while lines:
            if lines[0].startswith('If'):
                bl = ' '.join(lines.pop(0).split()[1:]).replace('(', '').replace(')', '')
                block.append({'name': bl, 'children': DecisionTreeVisualizer.parse(lines)})

            if lines and lines[0].startswith('Else'):
                be = ' '.join(lines.pop(0).split()[1:]).replace('(', '').replace(')', '')
                block.append({'name': be, 'children': DecisionTreeVisualizer.parse(lines)})
            elif lines and not lines[0].startswith(('If', 'Else')):
                block2 = lines.pop(0)
                block.append({'name': block2})
            else:
                break
        return block

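    # For reference, a Spark model.toDebugString looks roughly like the
    # following (exact formatting varies by Spark version); replacer() swaps
    # the `feature N` placeholders for real column names before parsing:
    #
    #   DecisionTreeClassificationModel of depth 1 with 3 nodes
    #     If (feature 0 <= 50000.0)
    #      Predict: 0.0
    #     Else (feature 0 > 50000.0)
    #      Predict: 1.0
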
    @staticmethod
    def tree_json(tree):
        """
        Generate a JSON representation of a decision tree
        :param tree: tree debug string
        :return: json
        """
        data = []
        for line in tree.splitlines():
            if line.strip():
                data.append(line.strip())
            else:
                break
        # Skip the header line of the debug string and parse the rest
        return {'name': 'Root', 'children': DecisionTreeVisualizer.parse(data[1:])}
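
# Example usage (a sketch; `dt_model`, the column list and the label names are
# placeholders, not from the original source):
#
#     DecisionTreeVisualizer.visualize(
#         dt_model,                          # a fitted decision tree model
#         ['age', 'income', 'num_purchases'],
#         ['below avg', 'above avg'],
#         'my_tree',                         # rendered to .../assets/images/my_tree.pdf
#         visual=True)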