from collections import defaultdict
from contextlib import contextmanager
from io import BytesIO
- from os import path
+ from os import path, remove
+ from shutil import rmtree
from zipfile import ZipFile
+ from sys import version as py_version

import gorilla
import mlflow
import requests
from requests.auth import HTTPBasicAuth
from mleap.pyspark import spark_support
+ import h2o
+ import pyspark

from splicemachine.mlflow_support.utilities import *
from splicemachine.spark.context import PySpliceContext
_CLIENT = mlflow.tracking.MlflowClient(tracking_uri=_TRACKING_URL)

_GORILLA_SETTINGS = gorilla.Settings(allow_hit=True, store_hit=True)
-
+ _PYTHON_VERSION = py_version.split('|')[0].strip()

def _mlflow_patch(name):
    """
@@ -110,8 +114,8 @@ def _lm(key, value):
    mlflow.log_metric(key, value)


- @_mlflow_patch('log_spark_model')
- def _log_spark_model(model, name='model'):
+ @_mlflow_patch('log_model')
+ def _log_model(model, name='model'):
    """
    Log a fitted spark pipeline or model
    :param model: (PipelineModel or Model) is the fitted Spark Model/Pipeline to store
@@ -120,26 +124,27 @@ def _log_spark_model(model, name='model'):
    """
    _check_for_splice_ctx()
    if _get_current_run_data().tags.get('splice.model_name'):  # this function has already run
-         raise Exception("Only one model is permitted per run.")
+         raise SpliceMachineException("Only one model is permitted per run.")

    mlflow.set_tag('splice.model_name', name)  # read in backend for deployment
-
-     jvm = mlflow._splice_context.jvm
-     java_import(jvm, "java.io.{BinaryOutputStream, ObjectOutputStream, ByteArrayInputStream}")
-
-     if not SparkUtils.is_spark_pipeline(model):
-         model = PipelineModel(
-             stages=[model]
-         )  # create a pipeline with only the model if a model is passed in
-
-     baos = jvm.java.io.ByteArrayOutputStream()  # serialize the PipelineModel to a byte array
-     oos = jvm.java.io.ObjectOutputStream(baos)
-     oos.writeObject(model._to_java())
-     oos.flush()
-     oos.close()
-     insert_artifact(mlflow._splice_context, name, baos.toByteArray(), mlflow.active_run().info.run_uuid,
-                     file_ext='sparkmodel')  # write the byte stream to the db as a BLOB
-
+     model_class = str(model.__class__)
+     mlflow.set_tag('splice.model_type', model_class)
+     mlflow.set_tag('splice.model_py_version', _PYTHON_VERSION)
+
+     run_id = mlflow.active_run().info.run_uuid
+     if 'h2o' in model_class.lower():
+         mlflow.set_tag('splice.h2o_version', h2o.__version__)
+         model_path = h2o.save_model(model=model, path='/tmp/model', force=True)
+         with open(model_path, 'rb') as artifact:
+             byte_stream = bytearray(bytes(artifact.read()))
+         insert_artifact(mlflow._splice_context, name, byte_stream, run_id, file_ext='h2omodel')
+         rmtree('/tmp/model')
+
+     elif 'spark' in model_class.lower():
+         mlflow.set_tag('splice.spark_version', pyspark.__version__)
+         SparkUtils.log_spark_model(mlflow._splice_context, model, name, run_id=run_id)
+     else:
+         raise SpliceMachineException('Currently we only support logging Spark and H2O models.')

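# Usage sketch for the renamed entry point: `splice`, `pipeline`, and `train_df`
# are hypothetical stand-ins for a registered PySpliceContext, a Spark Pipeline,
# and a training DataFrame (assumes apply_patches() has already run).
#
#     mlflow.register_splice_context(splice)
#     with mlflow.start_run(run_name='example'):
#         fitted = pipeline.fit(train_df)         # Spark PipelineModel -> stored with file_ext='sparkmodel'
#         mlflow.log_model(fitted, name='model')  # an H2O estimator would take the h2o branch above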
@_mlflow_patch('start_run')
def _start_run(run_id=None, tags=None, experiment_id=None, run_name=None, nested=False):
@@ -284,21 +289,18 @@ def _download_artifact(name, local_path, run_id=None):
    _check_for_splice_ctx()
    file_ext = path.splitext(local_path)[1]

-     if not file_ext:
-         raise ValueError('local_path variable must contain the file extension!')
-
    run_id = run_id or mlflow.active_run().info.run_uuid
-     blob_data = SparkUtils.retrieve_artifact_stream(mlflow._splice_context, run_id, name)
-     if file_ext == '.zip':
-         zip_file = ZipFile(BytesIO(blob_data))
-         zip_file.extractall()
-     else:
-         with open(local_path, 'wb') as artifact_file:
+     blob_data, f_ext = SparkUtils.retrieve_artifact_stream(mlflow._splice_context, run_id, name)
+
+     if not file_ext:  # if the user didn't provide the file (i.e. entered . as the local_path), fill it in for them
+         local_path += f'/{name}.{f_ext}'
+
+     with open(local_path, 'wb') as artifact_file:
        artifact_file.write(blob_data)


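# A sketch of the new behavior (`run_id` hypothetical): a local_path without an
# extension is treated as a directory and the stored name/extension are filled in,
# so this writes /tmp/artifacts/model.<stored_ext>; a path that already carries an
# extension (e.g. '/tmp/model.h2omodel') is written as given.
#
#     mlflow.download_artifact('model', '/tmp/artifacts', run_id=run_id)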
- @_mlflow_patch('load_spark_model')
- def _load_spark_model(run_id=None, name='model'):
+ @_mlflow_patch('load_model')
+ def _load_model(run_id=None, name='model'):
    """
    Download a model from database
    and load it into Spark
@@ -308,17 +310,17 @@ def _load_spark_model(run_id=None, name='model'):
    """
    _check_for_splice_ctx()
    run_id = run_id or mlflow.active_run().info.run_uuid
-     spark_pipeline_blob = SparkUtils.retrieve_artifact_stream(mlflow._splice_context, run_id, name)
-     bis = mlflow._splice_context.jvm.java.io.ByteArrayInputStream(spark_pipeline_blob)
-     ois = mlflow._splice_context.jvm.java.io.ObjectInputStream(bis)
-     pipeline = PipelineModel._from_java(ois.readObject())  # convert object from Java
-     # PipelineModel to Python PipelineModel
-     ois.close()
+     model_blob, file_ext = SparkUtils.retrieve_artifact_stream(mlflow._splice_context, run_id, name)

-     if len(pipeline.stages) == 1 and SparkUtils.is_spark_pipeline(pipeline.stages[0]):
-         pipeline = pipeline.stages[0]
+     if file_ext == 'sparkmodel':
+         model = SparkUtils.load_spark_model(mlflow._splice_context, model_blob)

-     return pipeline
+     elif file_ext == 'h2omodel':
+         with open('/tmp/model', 'wb') as file:
+             file.write(model_blob)
+         model = h2o.load_model('/tmp/model')
+         remove('/tmp/model')  # '/tmp/model' is a single file here, so os.remove (not rmtree) cleans it up
+     return model

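# Round-trip sketch: the stored extension picks the loader ('sparkmodel' ->
# SparkUtils.load_spark_model, 'h2omodel' -> h2o.load_model); `run_id` is hypothetical.
#
#     model = mlflow.load_model(run_id=run_id, name='model')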
@_mlflow_patch('log_artifact')
@@ -510,11 +512,6 @@ def _deploy_db(fittedPipe, df, db_schema_name, db_table_name, primary_key,
    db_table_name = db_table_name if db_table_name else f'data_{run_id}'
    schema_table_name = f'{db_schema_name}.{db_table_name}' if db_schema_name else db_table_name

-     # Get the VectorAssembler so we can get the features of the model
-     # FIXME: this might not be correct. If transformations are made before hitting the VectorAssembler, they
-     # FIXME: Also need to be included in the columns of the table. We need the df columns + VectorAssembler inputCols
-     # FIXME: We can do something similar to the log_feature_transformations function to get necessary columns
-     # FIXME: Or, this may just be df.columns ...
    feature_columns = df.columns
    # Get the datatype of each column in the dataframe
    schema_types = {str(i.name): re.sub("[0-9,()]", "", str(i.dataType)) for i in df.schema}
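    # For illustration: the re.sub above strips digits, commas, and parentheses, so a
    # column typed DecimalType(10,2) maps to 'DecimalType' and IntegerType to 'IntegerType'.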
@@ -562,7 +559,7 @@

    # Create table 2: DATA_PREDS
    print('Creating prediction table...', end=' ')
-     create_data_preds_table(mlflow._splice_context, schema_table_name, classes, primary_key, modelType, verbose)
+     create_data_preds_table(mlflow._splice_context, run_id, schema_table_name, classes, primary_key, modelType, verbose)
    print('Done.')

    # Create Trigger 1: (model prediction)
@@ -594,7 +591,7 @@ def apply_patches():
    ALL GORILLA PATCHES SHOULD BE PREFIXED WITH "_" BEFORE THEIR DESTINATION IN MLFLOW
    """
    targets = [_register_splice_context, _lp, _lm, _timer, _log_artifact, _log_feature_transformations,
-                _log_model_params, _log_pipeline_stages, _log_spark_model, _load_spark_model, _download_artifact,
+                _log_model_params, _log_pipeline_stages, _log_model, _load_model, _download_artifact,
               _start_run, _current_run_id, _current_exp_id, _deploy_aws, _deploy_azure, _deploy_db, _login_director]

    for target in targets:
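        # A minimal sketch of what each iteration presumably does (an assumption:
        # every target is exposed on mlflow under its function name minus the leading
        # underscore, per the "_" prefix convention in the docstring above):
        #     patch = gorilla.Patch(mlflow, target.__name__.lstrip('_'), target,
        #                           settings=_GORILLA_SETTINGS)
        #     gorilla.apply(patch)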