Skip to content
This repository was archived by the owner on Apr 15, 2022. It is now read-only.

Commit a81a69f

Browse files
Amrit BavejaAmrit Baveja
authored andcommitted
BREAKING CHANGES: restructured and added ml zeppelin code
1 parent c33b3f8 commit a81a69f

File tree

8 files changed

+335
-17
lines changed

8 files changed

+335
-17
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
limitations under the License.
1515
"""
1616

17-
from setuptools import setup
17+
from setuptools import setup, find_packages
1818

1919
dependencies = [
2020
"atomicwrites==1.1.5",
@@ -31,5 +31,5 @@
3131
name="splicemachine",
3232
version="0.2.2",
3333
install_requires=dependencies,
34-
packages=['splicemachine'],
34+
packages=find_packages(),
3535
)

splicemachine/__init__.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +0,0 @@
1-
"""
2-
Copyright 2018 Splice Machine, Inc.
3-
4-
Licensed under the Apache License, Version 2.0 (the "License");
5-
you may not use this file except in compliance with the License.
6-
You may obtain a copy of the License at
7-
8-
http://www.apache.org/licenses/LICENSE-2.0
9-
10-
Unless required by applicable law or agreed to in writing, software
11-
distributed under the License is distributed on an "AS IS" BASIS,
12-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13-
See the License for the specific language governing permissions and
14-
limitations under the License.
15-
"""

splicemachine/ml/__init__.py

Whitespace-only changes.

splicemachine/ml/zeppelin.py

Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
import os
2+
import random
3+
import time
4+
5+
import graphviz
6+
import mlflow
7+
import mlflow.spark
8+
import numpy as np
9+
from pyspark.sql import Row
10+
11+
"""
12+
Some utilities for use in Zeppelin when doing machine learning
13+
"""
14+
15+
16+
class Run:
17+
"""
18+
An abstraction over MLFlow Runs, allowing you to do cross cell runs
19+
"""
20+
21+
def __init__(self):
22+
self.run_uuid = None
23+
24+
@staticmethod
25+
def handle_handlers(handler, *args, **kwargs):
26+
if handler == 'param':
27+
mlflow.log_param(*args, **kwargs)
28+
elif handler == 'metric':
29+
mlflow.log_metric(*args, **kwargs)
30+
elif handler == 'artifact':
31+
mlflow.log_artifact(*args, **kwargs)
32+
elif handler == 'spark_model':
33+
mlflow.spark.log_model(*args, **kwargs)
34+
else:
35+
raise Exception(
36+
"Handler {0} not understood. Please use one in ['param', 'metric', "
37+
"'artifact', 'spark_model']")
38+
39+
def log_metadata(self, handler, *args, **kwargs):
40+
if not self.run_uuid:
41+
with mlflow.start_run():
42+
self.run_uuid = (mlflow.active_run().__dict__['_info'].__dict__['_run_uuid'])
43+
print("Logged using handler " + handler)
44+
Run.handle_handlers(handler, *args, **kwargs)
45+
else:
46+
with mlflow.start_run(run_uuid=self.run_uuid):
47+
Run.handle_handlers(handler, *args, **kwargs)
48+
print("Logged using handler " + handler)
49+
return True
50+
51+
def log_param(self, *args, **kwargs):
52+
return self.log_metadata('param', *args, **kwargs)
53+
54+
def log_metric(self, *args, **kwargs):
55+
return self.log_metadata('metric', *args, **kwargs)
56+
57+
def log_artifact(self, *args, **kwargs):
58+
return self.log_metadata('artifact', *args, **kwargs)
59+
60+
def log_model(self, *args, **kwargs):
61+
return self.log_metadata('spark_model', *args, **kwargs)
62+
63+
def create_new_run(self):
64+
"""
65+
Create a new Run
66+
:return:
67+
"""
68+
self.run_uuid = None
69+
70+
71+
def show_confusion_matrix(TP, TN, FP, FN):
72+
"""
73+
function that shows you a device called a confusion matrix... will be helpful when evaluating. It allows you to see how well your model performs
74+
:param TP: True Positives
75+
:param TN: True Negatives
76+
:param FP: False Positives
77+
:param FN: False Negatives
78+
"""
79+
confusion_rdd = sc.parallelize([['Predicted', TP, FN], ['Actual', FP, TN]])
80+
confusion_matrix = sqlContext.createDataFrame(confusion_rdd, ['', 'Actual', 'Predicted'])
81+
confusion_matrix.show()
82+
83+
84+
def experiment_maker(experiment_id):
85+
"""
86+
a function that creates a new experiment if "experiment_name" doesn't exist
87+
or will use the current one if it already does
88+
:param experiment_id the experiment name you would like to get or create
89+
"""
90+
print("Tracking Path " + mlflow.get_tracking_uri())
91+
found = False
92+
if not len(experiment_id) in [0, 1]:
93+
for e in [i for i in mlflow.tracking.list_experiments()]: # Check all experiments
94+
if experiment_id == e.name:
95+
print('Experiment has already been created')
96+
found = True
97+
os.environ['MLFLOW_EXPERIMENT_ID'] = str(
98+
e._experiment_id) # use already created experiment
99+
100+
if not found:
101+
_id = mlflow.tracking.create_experiment(experiment_id) # create new experiment
102+
print('Success! Created Experiment')
103+
os.environ['MLFLOW_EXPERIMENT_ID'] = str(_id) # use it
104+
else:
105+
print("Please fill out this field")
106+
107+
108+
class ModelEvaluator(object):
109+
"""
110+
A Function that provides an easy way to evaluate models once, or over random iterations
111+
"""
112+
113+
def __init__(self, label_column='label', prediction_column='prediction', confusion_matrix=True):
114+
"""
115+
:param label_column: the column in the dataframe containing the correct output
116+
:param prediction_column: the column in the dataframe containing the prediction
117+
:param confusion_matrix: whether or not to show a confusion matrix after each input
118+
"""
119+
self.avg_tp = []
120+
self.avg_tn = []
121+
self.avg_fn = []
122+
self.avg_fp = []
123+
124+
self.label_column = label_column
125+
self.prediction_column = prediction_column
126+
self.confusion_matrix = confusion_matrix
127+
128+
def input(self, predictions_dataframe):
129+
"""
130+
Evaluate actual vs Predicted in a dataframe
131+
:param predictions_dataframe: the dataframe containing the label and the predicition
132+
"""
133+
134+
pred_v_lab = predictions_dataframe.select(self.label_column,
135+
self.prediction_column) # Select the actual and the predicted labels
136+
137+
self.avg_tp.append(pred_v_lab[(pred_v_lab.label == 1) & (
138+
pred_v_lab.prediction == 1)].count()) # Add confusion stats
139+
self.avg_tn.append(
140+
pred_v_lab[(pred_v_lab.label == 0) & (pred_v_lab.prediction == 0)].count())
141+
self.avg_fp.append(
142+
pred_v_lab[(pred_v_lab.label == 1) & (pred_v_lab.prediction == 0)].count())
143+
self.avg_fn.append(
144+
pred_v_lab[(pred_v_lab.label == 0) & (pred_v_lab.prediction == 1)].count())
145+
146+
if self.confusion_matrix:
147+
show_confusion_matrix(self.avg_tp[-1], self.avg_tn[-1], self.avg_fp[-1],
148+
self.avg_fn[-1]) # show the confusion matrix to the user
149+
150+
def get_results(self, output_type='dataframe'):
151+
"""
152+
Return a dictionary containing evaluated results
153+
:param output_type: either a dataframe or a dict (which to return)
154+
:return results: computed_metrics (dict) or computed_df (df)
155+
"""
156+
157+
TP = np.mean(self.avg_tp)
158+
TN = np.mean(self.avg_tn)
159+
FP = np.mean(self.avg_fp)
160+
FN = np.mean(self.avg_fn)
161+
162+
computed_metrics = {
163+
'TPR': float(TP) / (TP + FN),
164+
'SPC': float(TP) / (TP + FN),
165+
'PPV': float(TP) / (TP + FP),
166+
"NPV": float(TN) / (TN + FN),
167+
"FPR": float(FP) / (FP + TN),
168+
"FDR": float(FP) / (FP + TP),
169+
"FNR": float(FN) / (FN + TP),
170+
"ACC": float(TP + TN) / (TP + FN + FP + TN),
171+
"F1": float((2 * TP)) / ((2 * TP) + FP + FN),
172+
"MCC": float((TP * TN) - (FP * FN)) / (
173+
np.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN)))
174+
}
175+
176+
if output_type == 'dict':
177+
return computed_metrics
178+
179+
else:
180+
metrics_row = Row('TPR', 'SPC', 'PPV', 'NPV', 'FPR', 'FDR', 'FNR', 'ACC', 'F1', 'MCC')
181+
computed_row = metrics_row(*computed_metrics.values())
182+
computed_df = sqlContext.createDataFrame([computed_row])
183+
return computed_df
184+
185+
186+
def print_horizantal_line(l):
187+
print("".join(['-' * l]))
188+
189+
190+
def display(html):
191+
print("%angular")
192+
print(html)
193+
194+
195+
class DecisionTreeVisualizer(object):
196+
"""
197+
Visualize a decision tree, either in code like format, or graphviz
198+
"""
199+
200+
@staticmethod
201+
def visualize(model, feature_column_names, label_names, tree_name, visual=True):
202+
"""
203+
Visualize a decision tree, either in a code like format, or graphviz
204+
:param model: the fitted decision tree classifier
205+
:param feature_column_names: column names for features
206+
:param label_names: labels vector (below avg, above avg)
207+
:param tree_name: the name you would like to call the tree
208+
:param visual: bool, true if you want a graphviz pdf containing your file
209+
:return: none
210+
"""
211+
212+
tree_to_json = DecisionTreeVisualizer.replacer(model.toDebugString,
213+
['feature ' + str(i) for i in
214+
range(0, len(feature_column_names))],
215+
feature_column_names)
216+
tree_to_json = DecisionTreeVisualizer.replacer(tree_to_json,
217+
['Predict ' + str(i) + '.0' for i in
218+
range(0, len(label_names))], label_names)
219+
if not visual:
220+
return tree_to_json
221+
222+
dot = graphviz.Digraph(comment='Decision Tree')
223+
dot.attr(size="7.75,15.25")
224+
dot.node_attr.update(color='lightblue2', style='filled')
225+
json_d = DecisionTreeVisualizer.tree_json(tree_to_json)
226+
dot.format = 'pdf'
227+
228+
DecisionTreeVisualizer.add_node(dot, '', '', json_d, realroot=True)
229+
dot.render('/zeppelin/webapps/webapp/assets/images/' + tree_name)
230+
print('Successfully uploaded file to Zeppelin Assests on this cluster')
231+
print('Uploading.')
232+
233+
time.sleep(3)
234+
print('Uploading..')
235+
time.sleep(3)
236+
237+
print('You can find your visualization at "https://docs.google.com/gview?url=https'
238+
'://<cluster_name>.splicemachine.io/assets/images/' + tree_name + '.pdf&embedded=tru'
239+
'e#view=fith')
240+
241+
@staticmethod
242+
def replacer(string, bad, good):
243+
"""
244+
Replace every string in "bad" with the corresponding string in "good"
245+
:param string: string to replace in
246+
:param bad: array of strings to replace
247+
:param good: array of strings to replace with
248+
:return:
249+
"""
250+
for b, g in zip(bad, good):
251+
string = string.replace(b, g)
252+
return string
253+
254+
@staticmethod
255+
def add_node(dot, parent, node_hash, root, realroot=False):
256+
"""
257+
Traverse through the .debugString json and generate a graphviz tree
258+
:param dot: dot file objevt
259+
:param parent: not used currently
260+
:param node_hash: unique node id
261+
:param root: the root of tree
262+
:param realroot: whether or not it is the real root, or a recursive root
263+
:return:
264+
"""
265+
node_id = str(hash(root['name'])) + str(random.randint(0, 100))
266+
if root:
267+
dot.node(node_id, root['name'])
268+
if not realroot:
269+
dot.edge(node_hash, node_id)
270+
if root.get('children'):
271+
if not root['children'][0].get('children'):
272+
DecisionTreeVisualizer.add_node(dot, root['name'], node_id, root['children'][0])
273+
else:
274+
DecisionTreeVisualizer.add_node(dot, root['name'], node_id, root['children'][0])
275+
DecisionTreeVisualizer.add_node(dot, root['name'], node_id, root['children'][1])
276+
277+
@staticmethod
278+
def parse(lines):
279+
"""
280+
Lines in debug string
281+
:param lines:
282+
:return: block json
283+
"""
284+
block = []
285+
while lines:
286+
287+
if lines[0].startswith('If'):
288+
bl = ' '.join(lines.pop(0).split()[1:]).replace('(', '').replace(')', '')
289+
block.append({'name': bl, 'children': DecisionTreeVisualizer.parse(lines)})
290+
291+
if lines[0].startswith('Else'):
292+
be = ' '.join(lines.pop(0).split()[1:]).replace('(', '').replace(')', '')
293+
block.append({'name': be, 'children': DecisionTreeVisualizer.parse(lines)})
294+
elif not lines[0].startswith(('If', 'Else')):
295+
block2 = lines.pop(0)
296+
block.append({'name': block2})
297+
else:
298+
break
299+
return block
300+
301+
@staticmethod
302+
def tree_json(tree):
303+
"""
304+
Generate a JSON representation of a decision tree
305+
:param tree: tree debug string
306+
:return: json
307+
"""
308+
data = []
309+
for line in tree.splitlines():
310+
if line.strip():
311+
line = line.strip()
312+
data.append(line)
313+
else:
314+
break
315+
if not line:
316+
break
317+
res = [{'name': 'Root', 'children': DecisionTreeVisualizer.parse(data[1:])}]
318+
return res[0]

splicemachine/spark/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
2+
Copyright 2018 Splice Machine, Inc.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)