11import time
22from collections import defaultdict
33from contextlib import contextmanager
4- from io import BytesIO
5- from os import path , remove
6- from shutil import rmtree
7- from zipfile import ZipFile
4+ from os import path
85from sys import version as py_version
96
107import gorilla
118import mlflow
129import requests
1310from requests .auth import HTTPBasicAuth
1411from mleap .pyspark import spark_support
15- import h2o
1612import pyspark
13+ from pyspark .ml .base import Estimator as SparkModel
14+ import sklearn
15+ from sklearn .base import BaseEstimator as ScikitModel
16+ from tensorflow import __version__ as tf_version
17+ from tensorflow .keras import __version__ as keras_version
18+ from tensorflow .keras import Model as KerasModel
1719
1820from splicemachine .mlflow_support .constants import *
1921from splicemachine .mlflow_support .utilities import *
2426_TRACKING_URL = get_pod_uri ("mlflow" , "5001" , _TESTING )
2527
2628_CLIENT = mlflow .tracking .MlflowClient (tracking_uri = _TRACKING_URL )
29+ mlflow .client = _CLIENT
2730
2831_GORILLA_SETTINGS = gorilla .Settings (allow_hit = True , store_hit = True )
2932_PYTHON_VERSION = py_version .split ('|' )[0 ].strip ()
@@ -50,6 +53,23 @@ def _get_current_run_data():
5053 return _CLIENT .get_run (mlflow .active_run ().info .run_id ).data
5154
5255
@_mlflow_patch('get_run_ids_by_name')
def _get_run_ids_by_name(run_name, experiment_id=None):
    """
    Gets a run id from the run name. If there are multiple runs with the same name, all run IDs are returned
    :param run_name: The name of the run
    :param experiment_id: The experiment to search in. If None, all experiments are searched
    :return: List of run ids
    """
    # `is not None` rather than truthiness: mlflow's default experiment has id 0,
    # which would otherwise be treated as "no experiment supplied".
    if experiment_id is not None:
        exp_ids = [experiment_id]
    else:
        # Normalize to plain experiment ids so both branches iterate the same type.
        # (The original put a raw id and Experiment objects in the same list and then
        # dereferenced `.experiment_id`, crashing whenever experiment_id was passed.)
        exp_ids = [exp.experiment_id for exp in _CLIENT.list_experiments()]

    run_ids = []
    for exp_id in exp_ids:
        for run in _CLIENT.search_runs(exp_id):
            # Not every run carries a name tag; .get avoids a KeyError on unnamed runs.
            if run.data.tags.get('mlflow.runName') == run_name:
                # NOTE(review): 'Run ID' is a custom tag presumably set by this library
                # at run creation — TODO confirm. Fall back to the canonical run id so a
                # matching run without the tag is still reported instead of raising.
                run_ids.append(run.data.tags.get('Run ID', run.info.run_id))
    return run_ids
72+
5373@_mlflow_patch ('register_splice_context' )
5474def _register_splice_context (splice_context ):
5575 """
@@ -70,7 +90,7 @@ def _check_for_splice_ctx():
7090
7191 if not hasattr (mlflow , '_splice_context' ):
7292 raise SpliceMachineException (
73- "You must run `mlflow.register_splice_context(py_splice_context ) before "
93+ "You must run `mlflow.register_splice_context(pysplice_context ) before "
7494 "you can run this mlflow operation!"
7595 )
7696
@@ -135,17 +155,26 @@ def _log_model(model, name='model'):
135155 run_id = mlflow .active_run ().info .run_uuid
136156 if 'h2o' in model_class .lower ():
137157 mlflow .set_tag ('splice.h2o_version' , h2o .__version__ )
138- model_path = h2o .save_model (model = model , path = '/tmp/model' , force = True )
139- with open (model_path , 'rb' ) as artifact :
140- byte_stream = bytearray (bytes (artifact .read ()))
141- insert_artifact (mlflow ._splice_context , name , byte_stream , run_id , file_ext = 'h2omodel' )
142- rmtree ('/tmp/model' )
158+ H2OUtils .log_h2o_model (mlflow ._splice_context , model , name , run_id )
143159
144- elif 'spark' in model_class . lower ( ):
160+ elif isinstance ( model , SparkModel ):
145161 mlflow .set_tag ('splice.spark_version' , pyspark .__version__ )
146- SparkUtils .log_spark_model (mlflow ._splice_context , model , name , run_id = run_id )
162+ SparkUtils .log_spark_model (mlflow ._splice_context , model , name , run_id )
163+
164+ elif isinstance (model , ScikitModel ):
165+ mlflow .set_tag ('splice.sklearn_version' , sklearn .__version__ )
166+ SKUtils .log_sklearn_model (mlflow ._splice_context , model , name , run_id )
167+
168+ elif isinstance (model , KerasModel ): # We can't handle keras models with a different backend
169+ mlflow .set_tag ('splice.keras_version' , keras_version )
170+ mlflow .set_tag ('splice.tf_version' , tf_version )
171+ KerasUtils .log_keras_model (mlflow ._splice_context , model , name , run_id )
172+
173+
147174 else :
148- raise SpliceMachineException ('Currently we only support logging Spark and H2O models.' )
175+ raise SpliceMachineException ('Model type not supported for logging.'
176+ 'Currently we support logging Spark, H2O, SKLearn and Keras (TF backend) models.'
177+ 'You can save your model to disk, zip it and run mlflow.log_artifact to save.' )
149178
150179@_mlflow_patch ('start_run' )
151180def _start_run (run_id = None , tags = None , experiment_id = None , run_name = None , nested = False ):
@@ -313,14 +342,18 @@ def _load_model(run_id=None, name='model'):
313342 run_id = run_id or mlflow .active_run ().info .run_uuid
314343 model_blob , file_ext = SparkUtils .retrieve_artifact_stream (mlflow ._splice_context , run_id , name )
315344
316- if file_ext == 'sparkmodel' :
345+ if file_ext == FileExtensions . spark :
317346 model = SparkUtils .load_spark_model (mlflow ._splice_context , model_blob )
347+ elif file_ext == FileExtensions .h2o :
348+ model = H2OUtils .load_h2o_model (model_blob )
349+ elif file_ext == FileExtensions .sklearn :
350+ model = SKUtils .load_sklearn_model (model_blob )
351+ elif file_ext == FileExtensions .keras :
352+ model = KerasUtils .load_keras_model (model_blob )
353+ else :
354+ raise SpliceMachineException (f'Model extension { file_ext } was not a supported model type. '
355+ f'Supported model extensions are { FileExtensions .get_valid ()} ' )
318356
319- elif file_ext == 'h2omodel' :
320- with open ('/tmp/model' , 'wb' ) as file :
321- file .write (model_blob )
322- model = h2o .load_model ('/tmp/model' )
323- remove ('/tmp/model' )
324357 return model
325358
326359
@@ -585,7 +618,8 @@ def apply_patches():
585618 """
586619 targets = [_register_splice_context , _lp , _lm , _timer , _log_artifact , _log_feature_transformations ,
587620 _log_model_params , _log_pipeline_stages , _log_model , _load_model , _download_artifact ,
588- _start_run , _current_run_id , _current_exp_id , _deploy_aws , _deploy_azure , _deploy_db , _login_director ]
621+ _start_run , _current_run_id , _current_exp_id , _deploy_aws , _deploy_azure , _deploy_db , _login_director ,
622+ _get_run_ids_by_name ]
589623
590624 for target in targets :
591625 gorilla .apply (gorilla .Patch (mlflow , target .__name__ .lstrip ('_' ), target , settings = _GORILLA_SETTINGS ))
0 commit comments