Skip to content
This repository was archived by the owner on Apr 15, 2022. It is now read-only.

Commit 4ec926a

Browse files
authored
Merge pull request #2 from splicemachine/refactor_and_mlflow_advanced
Refactor and mlflow advanced
2 parents 5dea5ec + 3fde9b3 commit 4ec926a

File tree

4 files changed

+537
-510
lines changed

4 files changed

+537
-510
lines changed

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,13 @@
3333
]
3434
setup(
3535
name="splicemachine",
36-
version="0.3.1",
36+
version="0.4.0",
3737
install_requires=dependencies,
3838
packages=find_packages(),
3939
license='Apache License, Version 2.0',
4040
long_description=open('README.md').read(),
4141
author="Splice Machine, Inc.",
42+
author_email="[email protected]",
4243
description="This package contains all of the classes and functions you need to interact with Splice Machine's scale out, Hadoop on SQL RDBMS from Python. It also contains several machine learning utilities for use with Apache Spark.",
4344
url="https://github.com/splicemachine/pysplice/"
4445
)

splicemachine/ml/management.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import mlflow
2+
import mlflow.sklearn
3+
import mlflow.spark
4+
import mlflow.h2o
5+
from mlflow.tracking import MlflowClient
6+
7+
class MLManager(MlflowClient):
8+
"""
9+
A class for managing your MLFlow Runs/Experiments
10+
"""
11+
12+
def __init__(self, _tracking_uri=mlflow.get_tracking_uri(), _artifact_uri=mlflow.get_artifact_uri()):
13+
MlflowClient.__init__(self, _tracking_uri)
14+
self.artifact_uri = _artifact_uri
15+
self.active_run = None
16+
self.active_experiment = None
17+
18+
@staticmethod
19+
def __removekey(d, key):
20+
"""
21+
Remove a key from a dictionary
22+
"""
23+
r = dict(d)
24+
del r[key]
25+
return r
26+
27+
def set_active_experiment(self, experiment_name):
28+
"""
29+
Set the active experiment of which all new runs will be created under
30+
Does not apply to already created runs
31+
32+
:param experiment_name: either an integer (experiment id) or a string (experiment name)
33+
"""
34+
35+
if isinstance(experiment_name, str):
36+
self.active_experiment = self.get_experiment_by_name(experiment_name)
37+
38+
elif isinstance(experiment_name, int):
39+
self.active_experiment = self.get_experiment(experiment_name)
40+
41+
42+
def create_new_run(self, user_id="splice"):
43+
"""
44+
Create a new run in the active experiment and set it to be active
45+
:param user_id: the user who creates the run in the MLFlow UI
46+
"""
47+
if not self.active_experiment:
48+
raise Exception("You must set an experiment before you can create a run. Use MLFlowManager.set_active_experiment")
49+
50+
self.active_run = self.create_run(self.active_experiment.experiment_id, user_id=user_id)
51+
52+
def set_active_run(self, run_id):
53+
"""
54+
Set the active run to a previous run (allows you to log metadata for completed run)
55+
:param run_id: the run UUID for the previous run
56+
"""
57+
self.active_run = self.get_run(run_id)
58+
59+
def __log_param(self, *args, **kwargs):
60+
super(MLManager, self).log_param(self.active_run.info.run_uuid, *args, **kwargs)
61+
62+
def log_param(self, *args, **kwargs):
63+
"""
64+
Log a parameter for the active run
65+
"""
66+
self.__log_param(*args, **kwargs)
67+
68+
def __set_tag(self, *args, **kwargs):
69+
super(MLManager, self).log_tag(self.active_run.info.run_uuid, *args, **kwargs)
70+
71+
def set_tag(self, *args, **kwargs):
72+
"""
73+
Set a tag for the active run
74+
"""
75+
self.__set_tag(*args, **kwargs)
76+
77+
def __log_metric(self, *args, **kwargs):
78+
super(MLManager, self).log_metric(self.active_run.info.run_uuid, *args, **kwargs)
79+
80+
def log_metric(self, *args, **kwargs):
81+
"""
82+
Log a metric for the active run
83+
"""
84+
self.__log_metric(*args, **kwargs)
85+
86+
def __log_artifact(self, *args, **kwargs):
87+
super(MLManager, self).log_artifact(self.active_run.info.run_uuid, *args, **kwargs)
88+
89+
def log_artifact(self, *args, **kwargs):
90+
"""
91+
Log an artifact for the active run
92+
"""
93+
self.__log_artifact(*args, **kwargs)
94+
95+
def __log_artifacts(self, *args, **kwargs):
96+
super(MLManager, self).log_artifacts(self.active_run.info.run_uuid, *args, **kwargs)
97+
98+
def log_artifacts(self, *args, **kwargs):
99+
"""
100+
Log artifacts for the active run
101+
"""
102+
self.__log_artifacts(*args, **kwargs)
103+
104+
def log_model(self, model, module):
105+
"""
106+
Log a model for the active run
107+
:param model: the fitted model/pipeline (in spark) to log
108+
:param module: the module that this is part of (mlflow.spark, mlflow.sklearn etc)
109+
"""
110+
with mlflow.start_run(run_uuid=self.active_run.info.run_uuid):
111+
module.log_model(model, "spark_model")
112+
113+
def log_spark_model(self, model):
114+
"""
115+
Log a spark model
116+
:param model: the fitted pipeline/model to log
117+
"""
118+
self.log_model(model, mlflow.spark)

0 commit comments

Comments
 (0)