Skip to content
This repository was archived by the owner on Apr 15, 2022. It is now read-only.

Commit 0e5091d

Browse files
author
Ben Epstein
authored
Dbaas 4800 (#91)
* testing notebook save * add patch * wrong ext location * wrapping the active run * revert * wrapped active run * wrapped active run * check for splice before logging at end_run * add extra check for non active run
1 parent 54d7c75 commit 0e5091d

File tree

1 file changed

+67
-3
lines changed

1 file changed

+67
-3
lines changed

splicemachine/mlflow_support/mlflow_support.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,16 @@
4949
from io import BytesIO
5050
from os import path
5151
from sys import version as py_version, stderr
52-
from tempfile import TemporaryDirectory
52+
from tempfile import TemporaryDirectory, NamedTemporaryFile
5353
from zipfile import ZIP_DEFLATED, ZipFile
5454
from typing import Dict, Optional, List, Union
5555

5656
import gorilla
5757
import h2o
5858
import mlflow
5959
import mlflow.pyfunc
60+
from mlflow.tracking.fluent import ActiveRun
61+
from mlflow.entities import RunStatus
6062
import pyspark
6163
import requests
6264
import sklearn
@@ -79,6 +81,15 @@
7981
from splicemachine import SpliceMachineException
8082
from splicemachine.spark.context import PySpliceContext
8183

84+
# For recording notebook history
85+
try:
86+
from IPython import get_ipython
87+
import nbformat as nbf
88+
ipython = get_ipython()
89+
mlflow._notebook_history = True
90+
except:
91+
mlflow._notebook_history = False
92+
8293
_TESTING = os.environ.get("TESTING", False)
8394

8495
try:
@@ -94,6 +105,15 @@
94105
_GORILLA_SETTINGS = gorilla.Settings(allow_hit=True, store_hit=True)
95106
_PYTHON_VERSION = py_version.split('|')[0].strip()
96107

108+
class SpliceActiveRun(ActiveRun):
109+
"""
110+
A wrapped active run for Splice Machine that calls our custom mlflow.end_run, so we can record the notebook
111+
history
112+
"""
113+
def __exit__(self, exc_type, exc_val, exc_tb):
114+
status = RunStatus.FINISHED if exc_type is None else RunStatus.FAILED
115+
mlflow.end_run(RunStatus.to_string(status))
116+
return exc_type is None
97117

98118
def __try_auto_login():
99119
"""
@@ -329,6 +349,50 @@ def _log_model(model, name='model', conda_env=None, model_lib=None):
329349
mlflow.set_tag('splice.model_py_version', _PYTHON_VERSION)
330350

331351

352+
@_mlflow_patch('end_run')
353+
def _end_run(status=RunStatus.to_string(RunStatus.FINISHED), save_html=True):
354+
"""End an active MLflow run (if there is one).
355+
356+
.. code-block:: python
357+
:caption: Example
358+
359+
import mlflow
360+
361+
# Start run and get status
362+
mlflow.start_run()
363+
run = mlflow.active_run()
364+
print("run_id: {}; status: {}".format(run.info.run_id, run.info.status))
365+
366+
# End run and get status
367+
mlflow.end_run()
368+
run = mlflow.get_run(run.info.run_id)
369+
print("run_id: {}; status: {}".format(run.info.run_id, run.info.status))
370+
print("--")
371+
372+
# Check for any active runs
373+
print("Active run: {}".format(mlflow.active_run()))
374+
375+
.. code-block:: text
376+
:caption: Output
377+
378+
run_id: b47ee4563368419880b44ad8535f6371; status: RUNNING
379+
run_id: b47ee4563368419880b44ad8535f6371; status: FINISHED
380+
--
381+
Active run: None
382+
"""
383+
if mlflow._notebook_history and hasattr(mlflow, '_splice_context') and mlflow.active_run():
384+
with NamedTemporaryFile() as temp_file:
385+
nb = nbf.v4.new_notebook()
386+
nb['cells'] = [nbf.v4.new_code_cell(code) for code in ipython.history_manager.input_hist_raw]
387+
nbf.write(nb, temp_file.name)
388+
run_name = mlflow.get_run(mlflow.current_run_id()).to_dictionary()['data']['tags']['mlflow.runName']
389+
mlflow.log_artifact(temp_file.name, name=f'{run_name}_run_log.ipynb')
390+
typ,ext = ('html','html') if save_html else ('script','py')
391+
os.system(f'jupyter nbconvert --to {typ} {temp_file.name}')
392+
mlflow.log_artifact(f'{temp_file.name[:-1]}.{ext}', name=f'{run_name}_run_log.{ext}')
393+
orig = gorilla.get_original_attribute(mlflow, "end_run")
394+
orig(status=status)
395+
332396
@_mlflow_patch('start_run')
333397
def _start_run(run_id=None, tags=None, experiment_id=None, run_name=None, nested=False):
334398
"""
@@ -372,7 +436,7 @@ def _start_run(run_id=None, tags=None, experiment_id=None, run_name=None, nested
372436
if hasattr(mlflow,'_active_training_set'):
373437
mlflow._active_training_set._register_metadata(mlflow)
374438

375-
return active_run
439+
return SpliceActiveRun(active_run)
376440

377441

378442
@_mlflow_patch('log_pipeline_stages')
@@ -924,7 +988,7 @@ def apply_patches():
924988
targets = [_register_feature_store, _register_splice_context, _lp, _lm, _timer, _log_artifact, _log_feature_transformations,
925989
_log_model_params, _log_pipeline_stages, _log_model, _load_model, _download_artifact,
926990
_start_run, _current_run_id, _current_exp_id, _deploy_aws, _deploy_azure, _deploy_db, _login_director,
927-
_get_run_ids_by_name, _get_deployed_models, _deploy_kubernetes, _fetch_logs, _watch_job]
991+
_get_run_ids_by_name, _get_deployed_models, _deploy_kubernetes, _fetch_logs, _watch_job, _end_run]
928992

929993
for target in targets:
930994
gorilla.apply(gorilla.Patch(mlflow, target.__name__.lstrip('_'), target, settings=_GORILLA_SETTINGS))

0 commit comments

Comments
 (0)