|
49 | 49 | from io import BytesIO |
50 | 50 | from os import path |
51 | 51 | from sys import version as py_version, stderr |
52 | | -from tempfile import TemporaryDirectory |
| 52 | +from tempfile import TemporaryDirectory, NamedTemporaryFile |
53 | 53 | from zipfile import ZIP_DEFLATED, ZipFile |
54 | 54 | from typing import Dict, Optional, List, Union |
55 | 55 |
|
56 | 56 | import gorilla |
57 | 57 | import h2o |
58 | 58 | import mlflow |
59 | 59 | import mlflow.pyfunc |
| 60 | +from mlflow.tracking.fluent import ActiveRun |
| 61 | +from mlflow.entities import RunStatus |
60 | 62 | import pyspark |
61 | 63 | import requests |
62 | 64 | import sklearn |
|
79 | 81 | from splicemachine import SpliceMachineException |
80 | 82 | from splicemachine.spark.context import PySpliceContext |
81 | 83 |
|
# For recording notebook history.
# The feature is only switched on when an interactive IPython shell is actually running:
# get_ipython() returns None in a plain Python process even when IPython is installed, and the
# history export needs a live shell to read the input history from.
try:
    from IPython import get_ipython
    import nbformat as nbf
    ipython = get_ipython()
    mlflow._notebook_history = ipython is not None
except Exception:
    # Best-effort: a missing or broken IPython/nbformat install simply disables the feature.
    # (Exception rather than a bare except, so KeyboardInterrupt/SystemExit still propagate.)
    ipython = None
    mlflow._notebook_history = False
| 92 | + |
82 | 93 | _TESTING = os.environ.get("TESTING", False) |
83 | 94 |
|
84 | 95 | try: |
|
94 | 105 | _GORILLA_SETTINGS = gorilla.Settings(allow_hit=True, store_hit=True) |
95 | 106 | _PYTHON_VERSION = py_version.split('|')[0].strip() |
96 | 107 |
|
class SpliceActiveRun(ActiveRun):
    """
    ActiveRun wrapper for Splice Machine.

    Leaving the ``with`` block goes through ``mlflow.end_run`` (looked up on the mlflow module at
    call time), so the Splice Machine version of end_run, which records the notebook history,
    is the one that closes the run.
    """
    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Close the run when the ``with`` block exits.

        :param exc_type: exception class raised inside the block, or None if it exited cleanly
        :param exc_val: exception instance (unused)
        :param exc_tb: traceback (unused)
        :return: True when the block exited cleanly, False otherwise (so the exception propagates)
        """
        exited_cleanly = exc_type is None
        # A clean exit marks the run FINISHED; any exception marks it FAILED
        final_status = RunStatus.FINISHED if exited_cleanly else RunStatus.FAILED
        mlflow.end_run(RunStatus.to_string(final_status))
        return exited_cleanly
97 | 117 |
|
98 | 118 | def __try_auto_login(): |
99 | 119 | """ |
@@ -329,6 +349,50 @@ def _log_model(model, name='model', conda_env=None, model_lib=None): |
329 | 349 | mlflow.set_tag('splice.model_py_version', _PYTHON_VERSION) |
330 | 350 |
|
331 | 351 |
|
@_mlflow_patch('end_run')
def _end_run(status=RunStatus.to_string(RunStatus.FINISHED), save_html=True):
    """End an active MLflow run (if there is one).

    Before the run is closed, and only when notebook history recording is available (an IPython
    shell is running, a splice context has been registered on mlflow, and a run is active), the
    shell's input history is written out as a notebook and logged to the run as
    ``<run_name>_run_log.ipynb``, together with a converted copy produced by
    ``jupyter nbconvert`` (``<run_name>_run_log.html`` or ``<run_name>_run_log.py``).
    The original ``mlflow.end_run`` is always called afterwards, even if recording the history
    fails, so a failed export cannot leave the run active.

    :param status: the status string to end the run with. Defaults to FINISHED
    :param save_html: if True the converted copy of the history is HTML, otherwise a Python script

    .. code-block:: python
        :caption: Example

        import mlflow

        # Start run and get status
        mlflow.start_run()
        run = mlflow.active_run()
        print("run_id: {}; status: {}".format(run.info.run_id, run.info.status))

        # End run and get status
        mlflow.end_run()
        run = mlflow.get_run(run.info.run_id)
        print("run_id: {}; status: {}".format(run.info.run_id, run.info.status))
        print("--")

        # Check for any active runs
        print("Active run: {}".format(mlflow.active_run()))

    .. code-block:: text
        :caption: Output

        run_id: b47ee4563368419880b44ad8535f6371; status: RUNNING
        run_id: b47ee4563368419880b44ad8535f6371; status: FINISHED
        --
        Active run: None
    """
    import subprocess  # local import: only needed for the nbconvert call below

    try:
        # `ipython` is only looked up when the flag is set (short-circuit), and is None when the
        # module was imported outside of an IPython shell
        if (mlflow._notebook_history and ipython is not None
                and hasattr(mlflow, '_splice_context') and mlflow.active_run()):
            with NamedTemporaryFile() as temp_file:
                # Rebuild the session as a notebook: one code cell per input the user executed
                nb = nbf.v4.new_notebook()
                nb['cells'] = [nbf.v4.new_code_cell(code) for code in ipython.history_manager.input_hist_raw]
                nbf.write(nb, temp_file.name)

                # Runs are not guaranteed to carry a name tag; fall back to the run id
                run_id = mlflow.current_run_id()
                run_tags = mlflow.get_run(run_id).to_dictionary()['data']['tags']
                run_name = run_tags.get('mlflow.runName', run_id)
                mlflow.log_artifact(temp_file.name, name=f'{run_name}_run_log.ipynb')

                typ, ext = ('html', 'html') if save_html else ('script', 'py')
                # Argument list instead of a shell string, so the path is never shell-interpreted;
                # check=True surfaces a failed conversion here rather than as a missing file below
                subprocess.run(['jupyter', 'nbconvert', '--to', typ, temp_file.name], check=True)

                # NOTE(review): the [:-1] appears to rely on nbconvert dropping the last character
                # of an extension-less input name when it builds the output name - confirm against
                # the nbconvert version in use
                converted = f'{temp_file.name[:-1]}.{ext}'
                try:
                    mlflow.log_artifact(converted, name=f'{run_name}_run_log.{ext}')
                finally:
                    # The converted file is not owned by NamedTemporaryFile, so clean it up here
                    if os.path.exists(converted):
                        os.remove(converted)
    finally:
        # Always close the run, even when recording the notebook history failed
        orig = gorilla.get_original_attribute(mlflow, "end_run")
        orig(status=status)
| 395 | + |
332 | 396 | @_mlflow_patch('start_run') |
333 | 397 | def _start_run(run_id=None, tags=None, experiment_id=None, run_name=None, nested=False): |
334 | 398 | """ |
@@ -372,7 +436,7 @@ def _start_run(run_id=None, tags=None, experiment_id=None, run_name=None, nested |
372 | 436 | if hasattr(mlflow,'_active_training_set'): |
373 | 437 | mlflow._active_training_set._register_metadata(mlflow) |
374 | 438 |
|
375 | | - return active_run |
| 439 | + return SpliceActiveRun(active_run) |
376 | 440 |
|
377 | 441 |
|
378 | 442 | @_mlflow_patch('log_pipeline_stages') |
@@ -924,7 +988,7 @@ def apply_patches(): |
924 | 988 | targets = [_register_feature_store, _register_splice_context, _lp, _lm, _timer, _log_artifact, _log_feature_transformations, |
925 | 989 | _log_model_params, _log_pipeline_stages, _log_model, _load_model, _download_artifact, |
926 | 990 | _start_run, _current_run_id, _current_exp_id, _deploy_aws, _deploy_azure, _deploy_db, _login_director, |
927 | | - _get_run_ids_by_name, _get_deployed_models, _deploy_kubernetes, _fetch_logs, _watch_job] |
| 991 | + _get_run_ids_by_name, _get_deployed_models, _deploy_kubernetes, _fetch_logs, _watch_job, _end_run] |
928 | 992 |
|
929 | 993 | for target in targets: |
930 | 994 | gorilla.apply(gorilla.Patch(mlflow, target.__name__.lstrip('_'), target, settings=_GORILLA_SETTINGS)) |
|
0 commit comments