
Commit 22a52fa

Authored by Epstein

Dbaas 3809 (#51)

* support for saving/loading sklearn and keras models
* wrong parameter order
* missing read mode for keras file
* get run ids by name
* extra quote

1 parent e3295c1 commit 22a52fa

File tree: 3 files changed, 127 additions and 29 deletions

* splicemachine/mlflow_support/constants.py
* splicemachine/mlflow_support/mlflow_support.py
* splicemachine/mlflow_support/utilities.py

splicemachine/mlflow_support/constants.py (21 additions, 0 deletions)

```diff
@@ -23,3 +23,24 @@ class SparkModelType(Enum):
     REGRESSION = 1
     CLUSTERING_WITH_PROB = 2
     CLUSTERING_WO_PROB = 3
+
+
+class FileExtensions():
+    """
+    Class containing names for
+    valid File Extensions
+    """
+    spark: str = "spark"
+    keras: str = "h5"
+    h2o: str = "h2o"
+    sklearn: str = "pkl"
+
+    @staticmethod
+    def get_valid() -> tuple:
+        """
+        Return a tuple of the valid file extensions
+        in Database
+        :return: (tuple) valid statuses
+        """
+        return (
+            FileExtensions.spark, FileExtensions.keras, FileExtensions.h2o, FileExtensions.sklearn
+        )
```
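The new FileExtensions class centralizes the artifact extensions that were previously scattered as string literals ('sparkmodel', 'h2omodel'). A minimal sketch of how a caller might validate an extension against it; the check_extension helper is hypothetical, not part of this commit:

```python
from splicemachine.mlflow_support.constants import FileExtensions

def check_extension(file_ext: str) -> None:
    # Hypothetical guard; FileExtensions.get_valid() returns ("spark", "h5", "h2o", "pkl")
    if file_ext not in FileExtensions.get_valid():
        raise ValueError(f"Unsupported extension '{file_ext}'; "
                         f"expected one of {FileExtensions.get_valid()}")

check_extension("pkl")  # sklearn models are stored as pickle bytes
check_extension("h5")   # keras models are stored as HDF5 files
```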

splicemachine/mlflow_support/mlflow_support.py (55 additions, 21 deletions)
```diff
@@ -1,19 +1,21 @@
 import time
 from collections import defaultdict
 from contextlib import contextmanager
-from io import BytesIO
-from os import path, remove
-from shutil import rmtree
-from zipfile import ZipFile
+from os import path
 from sys import version as py_version
 
 import gorilla
 import mlflow
 import requests
 from requests.auth import HTTPBasicAuth
 from mleap.pyspark import spark_support
-import h2o
 import pyspark
+from pyspark.ml.base import Estimator as SparkModel
+import sklearn
+from sklearn.base import BaseEstimator as ScikitModel
+from tensorflow import __version__ as tf_version
+from tensorflow.keras import __version__ as keras_version
+from tensorflow.keras import Model as KerasModel
 
 from splicemachine.mlflow_support.constants import *
 from splicemachine.mlflow_support.utilities import *
```
```diff
@@ -24,6 +26,7 @@
 _TRACKING_URL = get_pod_uri("mlflow", "5001", _TESTING)
 
 _CLIENT = mlflow.tracking.MlflowClient(tracking_uri=_TRACKING_URL)
+mlflow.client = _CLIENT
 
 _GORILLA_SETTINGS = gorilla.Settings(allow_hit=True, store_hit=True)
 _PYTHON_VERSION = py_version.split('|')[0].strip()
```
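Binding the tracking client to mlflow.client makes the raw MlflowClient reachable alongside the patched module-level API. A sketch, assuming the splicemachine.mlflow_support module has been imported so the binding is in place (list_experiments is the same client call the new get_run_ids_by_name uses internally):

```python
import mlflow  # patched by splicemachine.mlflow_support

# Anything the patched API doesn't wrap can go through the client directly
for exp in mlflow.client.list_experiments():
    print(exp.experiment_id, exp.name)
```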
```diff
@@ -50,6 +53,23 @@ def _get_current_run_data():
     return _CLIENT.get_run(mlflow.active_run().info.run_id).data
 
 
+@_mlflow_patch('get_run_ids_by_name')
+def _get_run_ids_by_name(run_name, experiment_id=None):
+    """
+    Gets a run id from the run name. If there are multiple runs with the same name, all run IDs are returned
+    :param run_name: The name of the run
+    :param experiment_id: The experiment to search in. If None, all experiments are searched
+    :return: List of run ids
+    """
+    exps = [experiment_id] if experiment_id else _CLIENT.list_experiments()
+    run_ids = []
+    for exp in exps:
+        for run in _CLIENT.search_runs(exp.experiment_id):
+            if run_name == run.data.tags['mlflow.runName']:
+                run_ids.append(run.data.tags['Run ID'])
+    return run_ids
+
+
 @_mlflow_patch('register_splice_context')
 def _register_splice_context(splice_context):
     """
```
```diff
@@ -70,7 +90,7 @@ def _check_for_splice_ctx():
 
     if not hasattr(mlflow, '_splice_context'):
         raise SpliceMachineException(
-            "You must run `mlflow.register_splice_context(py_splice_context) before "
+            "You must run `mlflow.register_splice_context(pysplice_context) before "
             "you can run this mlflow operation!"
         )
 
```
```diff
@@ -135,17 +155,26 @@ def _log_model(model, name='model'):
     run_id = mlflow.active_run().info.run_uuid
     if 'h2o' in model_class.lower():
         mlflow.set_tag('splice.h2o_version', h2o.__version__)
-        model_path = h2o.save_model(model=model, path='/tmp/model', force=True)
-        with open(model_path, 'rb') as artifact:
-            byte_stream = bytearray(bytes(artifact.read()))
-        insert_artifact(mlflow._splice_context, name, byte_stream, run_id, file_ext='h2omodel')
-        rmtree('/tmp/model')
+        H2OUtils.log_h2o_model(mlflow._splice_context, model, name, run_id)
 
-    elif 'spark' in model_class.lower():
+    elif isinstance(model, SparkModel):
         mlflow.set_tag('splice.spark_version', pyspark.__version__)
-        SparkUtils.log_spark_model(mlflow._splice_context, model, name, run_id=run_id)
+        SparkUtils.log_spark_model(mlflow._splice_context, model, name, run_id)
+
+    elif isinstance(model, ScikitModel):
+        mlflow.set_tag('splice.sklearn_version', sklearn.__version__)
+        SKUtils.log_sklearn_model(mlflow._splice_context, model, name, run_id)
+
+    elif isinstance(model, KerasModel):  # We can't handle keras models with a different backend
+        mlflow.set_tag('splice.keras_version', keras_version)
+        mlflow.set_tag('splice.tf_version', tf_version)
+        KerasUtils.log_keras_model(mlflow._splice_context, model, name, run_id)
+
+
     else:
-        raise SpliceMachineException('Currently we only support logging Spark and H2O models.')
+        raise SpliceMachineException('Model type not supported for logging.'
+                                     'Currently we support logging Spark, H2O, SKLearn and Keras (TF backend) models.'
+                                     'You can save your model to disk, zip it and run mlflow.log_artifact to save.')
 
 @_mlflow_patch('start_run')
 def _start_run(run_id=None, tags=None, experiment_id=None, run_name=None, nested=False):
```
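With this dispatch, log_model routes on the model's actual type rather than its class-name string (except for H2O, which keeps the substring check). A sketch of logging a scikit-learn model, where splice, X_train, and y_train are assumed to exist in the caller's session:

```python
import mlflow  # patched by splicemachine.mlflow_support
from sklearn.ensemble import RandomForestClassifier

mlflow.register_splice_context(splice)  # splice: an existing PySpliceContext (assumed)

with mlflow.start_run(run_name='rf_baseline'):
    model = RandomForestClassifier(n_estimators=100).fit(X_train, y_train)
    # isinstance(model, sklearn.base.BaseEstimator) is True, so this takes the
    # SKUtils path and stores the pickled model under the "pkl" extension
    mlflow.log_model(model, name='model')
```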
```diff
@@ -313,14 +342,18 @@ def _load_model(run_id=None, name='model'):
     run_id = run_id or mlflow.active_run().info.run_uuid
     model_blob, file_ext = SparkUtils.retrieve_artifact_stream(mlflow._splice_context, run_id, name)
 
-    if file_ext == 'sparkmodel':
+    if file_ext == FileExtensions.spark:
         model = SparkUtils.load_spark_model(mlflow._splice_context, model_blob)
+    elif file_ext == FileExtensions.h2o:
+        model = H2OUtils.load_h2o_model(model_blob)
+    elif file_ext == FileExtensions.sklearn:
+        model = SKUtils.load_sklearn_model(model_blob)
+    elif file_ext == FileExtensions.keras:
+        model = KerasUtils.load_keras_model(model_blob)
+    else:
+        raise SpliceMachineException(f'Model extension {file_ext} was not a supported model type. '
+                                     f'Supported model extensions are {FileExtensions.get_valid()}')
 
-    elif file_ext == 'h2omodel':
-        with open('/tmp/model', 'wb') as file:
-            file.write(model_blob)
-        model = h2o.load_model('/tmp/model')
-        remove('/tmp/model')
     return model
 
 
```
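Loading is symmetric: the extension stored with the artifact selects the loader, and an unrecognized extension now raises a SpliceMachineException listing the valid options instead of failing at `return model` with an unbound variable. Continuing the hypothetical run above (X_test assumed):

```python
import mlflow  # patched by splicemachine.mlflow_support

run_id = mlflow.get_run_ids_by_name('rf_baseline')[0]
model = mlflow.load_model(run_id=run_id, name='model')  # dispatches on the stored "pkl" extension
predictions = model.predict(X_test)
```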

```diff
@@ -585,7 +618,8 @@ def apply_patches():
     """
     targets = [_register_splice_context, _lp, _lm, _timer, _log_artifact, _log_feature_transformations,
                _log_model_params, _log_pipeline_stages, _log_model, _load_model, _download_artifact,
-               _start_run, _current_run_id, _current_exp_id, _deploy_aws, _deploy_azure, _deploy_db, _login_director]
+               _start_run, _current_run_id, _current_exp_id, _deploy_aws, _deploy_azure, _deploy_db, _login_director,
+               _get_run_ids_by_name]
 
     for target in targets:
         gorilla.apply(gorilla.Patch(mlflow, target.__name__.lstrip('_'), target, settings=_GORILLA_SETTINGS))
```
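apply_patches uses gorilla to graft each private helper onto the mlflow module under its public name (the leading underscore is stripped), which is how _get_run_ids_by_name becomes callable as mlflow.get_run_ids_by_name. A minimal illustration of that mechanism, independent of this codebase:

```python
import gorilla
import mlflow

def _hello():
    return 'patched'

# Same pattern as apply_patches: strip the underscore, attach to the mlflow module
settings = gorilla.Settings(allow_hit=True, store_hit=True)
gorilla.apply(gorilla.Patch(mlflow, _hello.__name__.lstrip('_'), _hello, settings=settings))

assert mlflow.hello() == 'patched'
```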

splicemachine/mlflow_support/utilities.py (51 additions, 8 deletions)
```diff
@@ -1,14 +1,19 @@
 from os import environ as env_vars, popen as rbash, system as bash, remove
 from sys import getsizeof
+from shutil import rmtree
+from pickle import dumps as save_pickle_string, loads as load_pickle_string
+from io import BytesIO
+from h5py import File as h5_file
 import re
 
-from pyspark.ml import Pipeline, PipelineModel
 from pyspark.ml.base import Model as SparkModel
+from tensorflow.keras.models import load_model as load_kr_model
 from py4j.java_gateway import java_import
 
 from splicemachine.spark.constants import SQL_TYPES
 from splicemachine.mlflow_support.constants import *
 from mleap.pyspark.spark_support import SimpleSparkSerializer
+
 import h2o
 
 from pyspark.ml.pipeline import PipelineModel
```
```diff
@@ -102,8 +107,25 @@ def get_h2omojo_model(splice_context, model):
         raw_mojo = jvm.MojoModel.load(model_path)
         java_mojo_c = jvm.EasyPredictModelWrapper.Config().setModel(raw_mojo)
         java_mojo = jvm.EasyPredictModelWrapper(java_mojo_c)
+        remove('/tmp/model.zip')
         return java_mojo, raw_mojo
 
+    @staticmethod
+    def log_h2o_model(splice_context, model, name, run_id):
+        model_path = h2o.save_model(model=model, path='/tmp/model', force=True)
+        with open(model_path, 'rb') as artifact:
+            byte_stream = bytearray(bytes(artifact.read()))
+        insert_artifact(splice_context, name, byte_stream, run_id, file_ext=FileExtensions.h2o)
+        rmtree('/tmp/model')
+
+    @staticmethod
+    def load_h2o_model(model_blob):
+        with open('/tmp/model', 'wb') as file:
+            file.write(model_blob)
+        model = h2o.load_model('/tmp/model')
+        remove('/tmp/model')
+        return model
+
     @staticmethod
     def insert_h2omojo_model(splice_context, run_id, model):
         model_exists = splice_context.df(
```
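Both new H2O helpers stage through a fixed /tmp path because h2o.save_model and h2o.load_model only operate on the filesystem; it is the raw bytes that get inserted into, and read back from, the database BLOB. A sketch of the save half of that round trip, assuming a running H2O cluster and a hypothetical train.csv:

```python
import h2o
from h2o.estimators import H2OGradientBoostingEstimator

h2o.init()
frame = h2o.import_file('train.csv')  # hypothetical training data
gbm = H2OGradientBoostingEstimator()
gbm.train(y='label', training_frame=frame)

# What log_h2o_model does internally: write to /tmp, then read the bytes for the BLOB
path = h2o.save_model(model=gbm, path='/tmp/model', force=True)
with open(path, 'rb') as f:
    blob = bytearray(f.read())
```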
```diff
@@ -121,7 +143,29 @@ def insert_h2omojo_model(splice_context, run_id, model):
         insert_model(splice_context, run_id, byte_array, 'h2omojo', h2o.__version__)
 
 
+class SKUtils:
+    @staticmethod
+    def log_sklearn_model(splice_context, model, name, run_id):
+        byte_stream = save_pickle_string(model)
+        insert_artifact(splice_context, name, byte_stream, run_id, file_ext=FileExtensions.sklearn)
 
+    @staticmethod
+    def load_sklearn_model(model_blob):
+        return load_pickle_string(model_blob)
+
+class KerasUtils:
+    @staticmethod
+    def log_keras_model(splice_context, model, name, run_id):
+        model.save('/tmp/model.h5')
+        with open('/tmp/model.h5', 'rb') as f:
+            byte_stream = bytearray(bytes(f.read()))
+        insert_artifact(splice_context, name, byte_stream, run_id, file_ext=FileExtensions.keras)
+        remove('/tmp/model.h5')
+
+    @staticmethod
+    def load_keras_model(model_blob):
+        hfile = h5_file(BytesIO(model_blob), 'r')
+        return load_kr_model(hfile)
 
 class SparkUtils:
     @staticmethod
```
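The two new utility classes differ in how they cross the byte boundary: sklearn models are pickled straight to bytes with no temp file, while Keras models are written to /tmp/model.h5 on save but read back entirely in memory by wrapping the BLOB in a BytesIO and opening it as an h5py File. A self-contained sketch of the sklearn round trip (the Keras path additionally requires a TensorFlow install):

```python
from pickle import dumps as save_pickle_string, loads as load_pickle_string

import numpy as np
from sklearn.linear_model import LogisticRegression

X = np.random.rand(20, 3)
y = np.random.randint(0, 2, 20)
clf = LogisticRegression().fit(X, y)

blob = save_pickle_string(clf)       # the bytes log_sklearn_model inserts as a BLOB
restored = load_pickle_string(blob)  # what load_sklearn_model hands back
assert (restored.predict(X) == clf.predict(X)).all()
```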
```diff
@@ -285,8 +329,8 @@ def get_model_type(pipeline_or_model):
 
         return m_type
     @staticmethod
-    def log_spark_model(splice_ctx, model, name, run_id):
-        jvm = splice_ctx.jvm
+    def log_spark_model(splice_context, model, name, run_id):
+        jvm = splice_context.jvm
         java_import(jvm, "java.io.{BinaryOutputStream, ObjectOutputStream, ByteArrayInputStream}")
 
         if not SparkUtils.is_spark_pipeline(model):
```
```diff
@@ -299,8 +343,8 @@ def log_spark_model(splice_ctx, model, name, run_id):
         oos.writeObject(model._to_java())
         oos.flush()
         oos.close()
-        insert_artifact(splice_ctx, name, baos.toByteArray(), run_id,
-                        file_ext='sparkmodel')  # write the byte stream to the db as a BLOB
+        insert_artifact(splice_context, name, baos.toByteArray(), run_id,
+                        file_ext='spark')  # write the byte stream to the db as a BLOB
 
     @staticmethod
     def load_spark_model(splice_ctx, spark_pipeline_blob):
```
```diff
@@ -456,13 +500,13 @@ def get_mleap_model(splice_context, fittedPipe, df, run_id: str):
     bash('mkdir /tmp')
     # Serialize the Spark model into Mleap format
     if f'{run_id}.zip' in rbash('ls /tmp').read():
-        bash(f'rm /tmp/{run_id}.zip')
+        remove(f'/tmp/{run_id}.zip')
     fittedPipe.serializeToBundle(f"jar:file:///tmp/{run_id}.zip", df)
 
     jvm = splice_context.jvm
     java_import(jvm, "com.splicemachine.mlrunner.FileRetriever")
     obj = jvm.FileRetriever.loadBundle(f'jar:file:///tmp/{run_id}.zip')
-    bash(f'rm /tmp/{run_id}.zip"')
+    remove(f'/tmp/{run_id}.zip')
     return obj
 
 
```
```diff
@@ -522,7 +566,6 @@ def create_data_table(splice_context, schema_table_name, schema_str, primary_key
             f'A model has already been deployed to table {schema_table_name}. We currently only support deploying 1 model per table')
         SQL_TABLE = f'CREATE TABLE {schema_table_name} (\n' + schema_str
 
-        # FIXME: Add the run_id as a column with constant default value to always be the run_id
         pk_cols = ''
         for i in primary_key:
             # If pk is already in the schema_string, don't add another column. PK may be an existing value
```
