This repository was archived by the owner on Apr 15, 2022. It is now read-only.

Commit a3453cd

Authored by Epstein
Dbaas 3643 (#49)
* DBAAS-3642: made spark model saving and loading generic
* DBAAS-3643: model class is object, not string
* accounting for new columns
* fix version column
1 parent 4d68e10 commit a3453cd
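
In short, this commit swaps the Spark-only log_spark_model / load_spark_model patches for generic log_model / load_model entry points that accept either a fitted Spark pipeline or an H2O model, and records the model library and version alongside the stored bytes. A rough usage sketch of the patched API (the import style, Spark session, run name, and fitted_pipeline below are assumptions, not taken from this repo):

# Illustrative sketch only; the import style, Spark session, and fitted_pipeline are assumptions.
from splicemachine.spark.context import PySpliceContext
from splicemachine.mlflow_support.mlflow_support import mlflow  # module patched by apply_patches()

splice = PySpliceContext(spark)          # 'spark' is an existing SparkSession (assumed)
mlflow.register_splice_context(splice)   # required before logging or loading models

mlflow.start_run(run_name='dbaas-3643-demo')
mlflow.log_model(fitted_pipeline, name='model')   # accepts a Spark PipelineModel or an H2O model
run_id = mlflow.current_run_id()

reloaded = mlflow.load_model(run_id=run_id, name='model')

Either branch stores the serialized model as a BLOB in the MLMANAGER.ARTIFACTS table with a file extension ('sparkmodel' or 'h2omodel') that load_model later uses to pick the deserialization path.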

File tree: 2 files changed, +88 -55 lines

splicemachine/mlflow_support/mlflow_support.py

Lines changed: 46 additions & 49 deletions
@@ -2,14 +2,18 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from io import BytesIO
-from os import path
+from os import path, remove
+from shutil import rmtree
 from zipfile import ZipFile
+from sys import version as py_version

 import gorilla
 import mlflow
 import requests
 from requests.auth import HTTPBasicAuth
 from mleap.pyspark import spark_support
+import h2o
+import pyspark

 from splicemachine.mlflow_support.utilities import *
 from splicemachine.spark.context import PySpliceContext
@@ -21,7 +25,7 @@
 _CLIENT = mlflow.tracking.MlflowClient(tracking_uri=_TRACKING_URL)

 _GORILLA_SETTINGS = gorilla.Settings(allow_hit=True, store_hit=True)
-
+_PYTHON_VERSION = py_version.split('|')[0].strip()

 def _mlflow_patch(name):
     """
@@ -110,8 +114,8 @@ def _lm(key, value):
     mlflow.log_metric(key, value)


-@_mlflow_patch('log_spark_model')
-def _log_spark_model(model, name='model'):
+@_mlflow_patch('log_model')
+def _log_model(model, name='model'):
     """
     Log a fitted spark pipeline or model
     :param model: (PipelineModel or Model) is the fitted Spark Model/Pipeline to store
@@ -120,26 +124,27 @@ def _log_spark_model(model, name='model'):
     """
     _check_for_splice_ctx()
     if _get_current_run_data().tags.get('splice.model_name'):  # this function has already run
-        raise Exception("Only one model is permitted per run.")
+        raise SpliceMachineException("Only one model is permitted per run.")

     mlflow.set_tag('splice.model_name', name)  # read in backend for deployment
-
-    jvm = mlflow._splice_context.jvm
-    java_import(jvm, "java.io.{BinaryOutputStream, ObjectOutputStream, ByteArrayInputStream}")
-
-    if not SparkUtils.is_spark_pipeline(model):
-        model = PipelineModel(
-            stages=[model]
-        )  # create a pipeline with only the model if a model is passed in
-
-    baos = jvm.java.io.ByteArrayOutputStream()  # serialize the PipelineModel to a byte array
-    oos = jvm.java.io.ObjectOutputStream(baos)
-    oos.writeObject(model._to_java())
-    oos.flush()
-    oos.close()
-    insert_artifact(mlflow._splice_context, name, baos.toByteArray(), mlflow.active_run().info.run_uuid,
-                    file_ext='sparkmodel')  # write the byte stream to the db as a BLOB
-
+    model_class = str(model.__class__)
+    mlflow.set_tag('splice.model_type', model_class)
+    mlflow.set_tag('splice.model_py_version', _PYTHON_VERSION)
+
+    run_id = mlflow.active_run().info.run_uuid
+    if 'h2o' in model_class.lower():
+        mlflow.set_tag('splice.h2o_version', h2o.__version__)
+        model_path = h2o.save_model(model=model, path='/tmp/model', force=True)
+        with open(model_path, 'rb') as artifact:
+            byte_stream = bytearray(bytes(artifact.read()))
+        insert_artifact(mlflow._splice_context, name, byte_stream, run_id, file_ext='h2omodel')
+        rmtree('/tmp/model')
+
+    elif 'spark' in model_class.lower():
+        mlflow.set_tag('splice.spark_version', pyspark.__version__)
+        SparkUtils.log_spark_model(mlflow._splice_context, model, name, run_id=run_id)
+    else:
+        raise SpliceMachineException('Currently we only support logging Spark and H2O models.')

 @_mlflow_patch('start_run')
 def _start_run(run_id=None, tags=None, experiment_id=None, run_name=None, nested=False):
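
A side note on the dispatch above (an observation, not part of the commit): the branch is chosen by substring-matching the lower-cased class path of the model, so any class living under the h2o or pyspark packages routes to the right serializer. A small self-contained illustration:

# Not from the commit -- illustrates why the lower-cased substring check on str(model.__class__) works.
def _route(model_class: str) -> str:
    if 'h2o' in model_class.lower():
        return 'h2o branch'
    elif 'spark' in model_class.lower():
        return 'spark branch'
    return 'unsupported'

# A Spark pipeline reports "<class 'pyspark.ml.pipeline.PipelineModel'>"
print(_route("<class 'pyspark.ml.pipeline.PipelineModel'>"))        # -> spark branch
# H2O estimators live under the h2o package, e.g. H2OGradientBoostingEstimator
print(_route("<class 'h2o.estimators.gbm.H2OGradientBoostingEstimator'>"))  # -> h2o branch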
@@ -284,21 +289,18 @@ def _download_artifact(name, local_path, run_id=None):
     _check_for_splice_ctx()
     file_ext = path.splitext(local_path)[1]

-    if not file_ext:
-        raise ValueError('local_path variable must contain the file extension!')
-
     run_id = run_id or mlflow.active_run().info.run_uuid
-    blob_data = SparkUtils.retrieve_artifact_stream(mlflow._splice_context, run_id, name)
-    if file_ext == '.zip':
-        zip_file = ZipFile(BytesIO(blob_data))
-        zip_file.extractall()
-    else:
-        with open(local_path, 'wb') as artifact_file:
+    blob_data, f_etx = SparkUtils.retrieve_artifact_stream(mlflow._splice_context, run_id, name)
+
+    if not file_ext:  # If the user didn't provide the file (ie entered . as the local_path), fill it in for them
+        local_path += f'/{name}.{f_etx}'
+
+    with open(local_path, 'wb') as artifact_file:
         artifact_file.write(blob_data)


-@_mlflow_patch('load_spark_model')
-def _load_spark_model(run_id=None, name='model'):
+@_mlflow_patch('load_model')
+def _load_model(run_id=None, name='model'):
     """
     Download a model from database
     and load it into Spark
@@ -308,17 +310,17 @@ def _load_spark_model(run_id=None, name='model'):
     """
     _check_for_splice_ctx()
     run_id = run_id or mlflow.active_run().info.run_uuid
-    spark_pipeline_blob = SparkUtils.retrieve_artifact_stream(mlflow._splice_context, run_id, name)
-    bis = mlflow._splice_context.jvm.java.io.ByteArrayInputStream(spark_pipeline_blob)
-    ois = mlflow._splice_context.jvm.java.io.ObjectInputStream(bis)
-    pipeline = PipelineModel._from_java(ois.readObject())  # convert object from Java
-    # PipelineModel to Python PipelineModel
-    ois.close()
+    model_blob, file_ext = SparkUtils.retrieve_artifact_stream(mlflow._splice_context, run_id, name)

-    if len(pipeline.stages) == 1 and SparkUtils.is_spark_pipeline(pipeline.stages[0]):
-        pipeline = pipeline.stages[0]
+    if file_ext == 'sparkmodel':
+        model = SparkUtils.load_spark_model(mlflow._splice_context, model_blob)

-    return pipeline
+    elif file_ext == 'h2omodel':
+        with open('/tmp/model', 'wb') as file:
+            file.write(model_blob)
+        model = h2o.load_model('/tmp/model')
+        rmtree('/tmp/model')
+    return model


 @_mlflow_patch('log_artifact')
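
For orientation (not part of the diff): both patched functions above now lean on the stored file_extension, which fills in a default file name for downloads and selects the deserialization path for loads. A hedged usage sketch, with the run id and paths as placeholders:

# Sketch only; the run id and paths are placeholders, not values from this commit.
run_id = mlflow.current_run_id()

# Passing '.' (no extension) lets the stored file_extension complete the name,
# e.g. './model.sparkmodel' or './model.h2omodel'.
mlflow.download_artifact('model', '.', run_id=run_id)

# load_model dispatches on that same extension: 'sparkmodel' -> JVM deserialization,
# 'h2omodel' -> h2o.load_model on a temp file.
model = mlflow.load_model(run_id=run_id, name='model')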
@@ -510,11 +512,6 @@ def _deploy_db(fittedPipe, df, db_schema_name, db_table_name, primary_key,
     db_table_name = db_table_name if db_table_name else f'data_{run_id}'
     schema_table_name = f'{db_schema_name}.{db_table_name}' if db_schema_name else db_table_name

-    # Get the VectorAssembler so we can get the features of the model
-    # FIXME: this might not be correct. If transformations are made before hitting the VectorAssembler, they
-    # FIXME: Also need to be included in the columns of the table. We need the df columns + VectorAssembler inputCols
-    # FIXME: We can do something similar to the log_feature_transformations function to get necessary columns
-    # FIXME: Or, this may just be df.columns ...
     feature_columns = df.columns
     # Get the datatype of each column in the dataframe
     schema_types = {str(i.name): re.sub("[0-9,()]", "", str(i.dataType)) for i in df.schema}
@@ -562,7 +559,7 @@ def _deploy_db(fittedPipe, df, db_schema_name, db_table_name, primary_key,

     # Create table 2: DATA_PREDS
     print('Creating prediction table ...', end=' ')
-    create_data_preds_table(mlflow._splice_context, schema_table_name, classes, primary_key, modelType, verbose)
+    create_data_preds_table(mlflow._splice_context, run_id, schema_table_name, classes, primary_key, modelType, verbose)
     print('Done.')

     # Create Trigger 1: (model prediction)
@@ -594,7 +591,7 @@ def apply_patches():
     ALL GORILLA PATCHES SHOULD BE PREFIXED WITH "_" BEFORE THEIR DESTINATION IN MLFLOW
     """
     targets = [_register_splice_context, _lp, _lm, _timer, _log_artifact, _log_feature_transformations,
-               _log_model_params, _log_pipeline_stages, _log_spark_model, _load_spark_model, _download_artifact,
+               _log_model_params, _log_pipeline_stages, _log_model, _load_model, _download_artifact,
               _start_run, _current_run_id, _current_exp_id, _deploy_aws, _deploy_azure, _deploy_db, _login_director]

     for target in targets:

splicemachine/mlflow_support/utilities.py

Lines changed: 42 additions & 6 deletions
@@ -1,4 +1,4 @@
-from os import environ as env_vars, popen as rbash, system as bash
+from os import environ as env_vars, popen as rbash, system as bash, remove
 from sys import getsizeof
 import re

@@ -9,7 +9,9 @@
 from splicemachine.spark.constants import SQL_TYPES
 from splicemachine.mlflow_support.constants import SparkModelType
 from mleap.pyspark.spark_support import SimpleSparkSerializer
+import h2o

+from pyspark.ml.pipeline import PipelineModel

 class SpliceMachineException(Exception):
     pass
@@ -18,9 +20,9 @@ class SpliceMachineException(Exception):
 class SQL:
     MLMANAGER_SCHEMA = 'MLMANAGER'
     ARTIFACT_INSERT_SQL = f'INSERT INTO {MLMANAGER_SCHEMA}.ARTIFACTS (run_uuid, name, "size", "binary", file_extension) VALUES (?, ?, ?, ?, ?)'
-    ARTIFACT_RETRIEVAL_SQL = 'SELECT "binary" FROM ' + f'{MLMANAGER_SCHEMA}.' + 'ARTIFACTS WHERE name=\'{name}\' ' \
+    ARTIFACT_RETRIEVAL_SQL = 'SELECT "binary", file_extension FROM ' + f'{MLMANAGER_SCHEMA}.' + 'ARTIFACTS WHERE name=\'{name}\' ' \
                              'AND run_uuid=\'{runid}\''
-    MODEL_INSERT_SQL = f'INSERT INTO {MLMANAGER_SCHEMA}.MODELS(RUN_UUID, MODEL) VALUES (?, ?)'
+    MODEL_INSERT_SQL = f'INSERT INTO {MLMANAGER_SCHEMA}.MODELS(RUN_UUID, MODEL, LIBRARY, "version") VALUES (?, ?, ?, ?)'
     MODEL_RETRIEVAL_SQL = 'SELECT MODEL FROM {MLMANAGER_SCHEMA}.MODELS WHERE RUN_UUID=\'{run_uuid}\''

@@ -137,7 +139,7 @@ def retrieve_artifact_stream(splice_context, run_id, name):
        try:
            return splice_context.df(
                SQL.ARTIFACT_RETRIEVAL_SQL.format(name=name, runid=run_id)
-            ).collect()[0][0]
+            ).collect()[0]
        except IndexError as e:
            raise Exception(f"Unable to find the artifact with the given run id {run_id} and name {name}")

@@ -185,7 +187,37 @@ def get_model_type(pipeline_or_model):
            m_type = SparkModelType.CLUSTERING_WO_PROB

        return m_type
+    @staticmethod
+    def log_spark_model(splice_ctx, model, name, run_id):
+        jvm = splice_ctx.jvm
+        java_import(jvm, "java.io.{BinaryOutputStream, ObjectOutputStream, ByteArrayInputStream}")
+
+        if not SparkUtils.is_spark_pipeline(model):
+            model = PipelineModel(
+                stages=[model]
+            )  # create a pipeline with only the model if a model is passed in
+
+        baos = jvm.java.io.ByteArrayOutputStream()  # serialize the PipelineModel to a byte array
+        oos = jvm.java.io.ObjectOutputStream(baos)
+        oos.writeObject(model._to_java())
+        oos.flush()
+        oos.close()
+        insert_artifact(splice_ctx, name, baos.toByteArray(), run_id,
+                        file_ext='sparkmodel')  # write the byte stream to the db as a BLOB

+    @staticmethod
+    def load_spark_model(splice_ctx, spark_pipeline_blob):
+        jvm = splice_ctx.jvm
+        bis = jvm.java.io.ByteArrayInputStream(spark_pipeline_blob)
+        ois = jvm.java.io.ObjectInputStream(bis)
+        pipeline = PipelineModel._from_java(ois.readObject())  # convert object from Java
+        # PipelineModel to Python PipelineModel
+        ois.close()
+
+        if len(pipeline.stages) == 1 and not SparkUtils.is_spark_pipeline(pipeline.stages[0]):
+            pipeline = pipeline.stages[0]
+
+        return pipeline

 def find_inputs_by_output(dictionary, value):
     """
@@ -231,6 +263,9 @@ def insert_model(splice_context, run_id, byte_array):
     prepared_statement = db_connection.prepareStatement(SQL.MODEL_INSERT_SQL)
     prepared_statement.setString(1, run_id)
     prepared_statement.setBinaryStream(2, binary_input_stream)
+    # FIXME: Dynamically set this per model type (only mleap for now)
+    prepared_statement.setString(3, 'MLEAP')
+    prepared_statement.setString(4, '0.15.0')

     prepared_statement.execute()
     prepared_statement.close()
@@ -368,12 +403,13 @@ def create_data_table(splice_context, schema_table_name, schema_str, primary_key
     splice_context.execute(SQL_TABLE)


-def create_data_preds_table(splice_context, schema_table_name, classes, primary_key,
+def create_data_preds_table(splice_context, run_id, schema_table_name, classes, primary_key,
                             modelType, verbose):
     """
     Creates the data prediction table that holds the prediction for the rows of the data table
     :param splice_context: pysplicectx
     :param schema_table_name: (str) the schema.table to create the table under
+    :param run_id: (str) the run_id for this model
     :param classes: (List[str]) the labels of the model (if they exist)
     :param primary_key: List[Tuple[str,str]] column name, SQL datatype for the primary key(s) of the table
     :param modelType: (ModelType) Whether the model is a Regression, Classification or Clustering (with/without probabilities)
@@ -386,8 +422,8 @@ def create_data_preds_table(splice_context, schema_table_name, classes, primary_
     SQL_PRED_TABLE = f'''CREATE TABLE {schema_table_name}_PREDS (
     \tCUR_USER VARCHAR(50) DEFAULT CURRENT_USER,
     \tEVAL_TIME TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+    \tRUN_ID VARCHAR(50) DEFAULT \'{run_id}\',
     '''
-    # FIXME: Add the run_id as a column with constant default value to always be the run_id
     pk_cols = ''
     for i in primary_key:
         SQL_PRED_TABLE += f'\t{i[0]} {i[1]},\n'
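
To make the new column concrete, the snippet below (illustration only; the schema, table, and key names are invented) mirrors how create_data_preds_table now builds the DDL, with the run id baked in as a constant DEFAULT so every prediction row is stamped with the model that produced it:

# Illustration only: schema/table/primary key names are invented for this example.
run_id = 'a3453cdexample'
schema_table_name = 'MYSCHEMA.IRIS'
sql_pred_table = f'''CREATE TABLE {schema_table_name}_PREDS (
\tCUR_USER VARCHAR(50) DEFAULT CURRENT_USER,
\tEVAL_TIME TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
\tRUN_ID VARCHAR(50) DEFAULT '{run_id}',
'''
sql_pred_table += '\tID INT,\n'   # primary key column(s) are appended in the loop
print(sql_pred_table)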
