Skip to content
This repository was archived by the owner on Apr 15, 2022. It is now read-only.

Commit 68e63b2

Browse files
author
Epstein
authored
Add support for run_name parameter (#24)
Removed duplicate code and added support for adding a run_name to a new run instead of a randomly generated string
1 parent 3025af8 commit 68e63b2

File tree

1 file changed

+16
-51
lines changed

1 file changed

+16
-51
lines changed

splicemachine/ml/management.py

Lines changed: 16 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,17 @@ def get_pod_uri(pod, port, pod_count=0, testing=False):
2727
raise KeyError(
2828
"Uh Oh! MLFLOW_URL variable was not found... are you running in the Cloud service?")
2929

30-
===
30+
3131
from mlflow.exceptions import MlflowException
3232

3333
def _get_user():
3434
"""
3535
Get the current logged in user to
3636
Jupyter
37-
3837
:return: (str) name of the logged in user
3938
"""
4039
try:
41-
uname = env_vars['JUPYTERHUB_USER']
40+
uname = env_vars.get('JUPYTERHUB_USER') or env_vars['USER']
4241
return uname
4342
except KeyError:
4443
raise Exception("Could not determine current running user. Running MLManager outside of Splice Machine Cloud Jupyter "
@@ -59,7 +58,6 @@ def _readable_pipeline_stage(pipeline_stage):
5958
def _get_stages(pipeline):
6059
"""
6160
Extract the stages from a fit or unfit pipeline
62-
6361
:param pipeline: a fit or unfit Spark pipeline
6462
:return: stages list
6563
"""
@@ -200,7 +198,7 @@ def wrapped(self, *args, **kwargs):
200198
raise Exception("Please either use set_active_experiment or create_experiment "
201199
"to set an active experiment before running this function")
202200
elif not self.active_run:
203-
raise Exception("Please either use set_active_run or create_run to set an active "
201+
raise Exception("Please either use set_active_run or start_run to set an active "
204202
"run before running this function")
205203
else:
206204
return func(self, *args, **kwargs)
@@ -264,41 +262,11 @@ def create_experiment(self, experiment_name, reset=False):
264262
self.active_experiment = self.get_experiment_by_name(experiment_name)
265263
print("Set experiment id=" + str(experiment_id) + " to the active experiment")
266264

267-
def create_experiment(self, experiment_name, reset=False):
268-
"""
269-
Create a new experiment. If the experiment
270-
already exists, it will be set to active experiment.
271-
If the experiment doesn't exist, it will be created
272-
and set to active. If the reset option is set to true
273-
(please use with caution), the runs within the existing
274-
experiment will be deleted
275-
:param experiment_name: (str) the name of the experiment to create
276-
:param reset: (bool) whether or not to overwrite the existing run
277-
"""
278-
experiment = self.get_experiment_by_name(experiment_name)
279-
if experiment:
280-
print("Experiment " + experiment_name + " already exists... setting to active experiment")
281-
self.active_experiment = experiment
282-
print("Active experiment has id " + str(experiment.id))
283-
if reset:
284-
print("Keyword argument \"reset\" was set to True. Overwriting experiment and its associated runs...")
285-
experiment_id = self.active_experiment.experiment_id
286-
associated_runs = self.list_run_infos(experiment_id)
287-
for run in associated_runs:
288-
print("Deleting run with UUID " + run.run_uuid)
289-
manager.delete_run(run.run_uuid)
290-
print("Successfully overwrote experiment")
291-
else:
292-
experiment_id = super(MLManager, self).create_experiment(experiment_name)
293-
print("Created experiment w/ id=" + str(experiment_id))
294-
self.set_active_experiment(experiment_id)
295-
296265

297266
def set_active_experiment(self, experiment_name):
298267
"""
299268
Set the active experiment of which all new runs will be created under
300269
Does not apply to already created runs
301-
302270
:param experiment_name: either an integer (experiment id) or a string (experiment name)
303271
"""
304272

@@ -315,7 +283,7 @@ def set_active_run(self, run_id):
315283
"""
316284
self.active_run = self.get_run(run_id)
317285

318-
def start_run(self, tags=None, experiment_id=None):
286+
def start_run(self, tags=None, run_name=None, experiment_id=None, nested=False):
319287
"""
320288
Create a new run in the active experiment and set it to be active
321289
:param tags: a dictionary containing metadata about the current run.
@@ -327,7 +295,7 @@ def start_run(self, tags=None, experiment_id=None):
327295
:param run_name: an optional name for the run to show up in the MLFlow UI
328296
:param experiment_id: if this is specified, the experiment id of this
329297
will override the active run.
330-
298+
:param nester: Controls whether run is nested in parent run. True creates a nest run
331299
"""
332300
if experiment_id:
333301
new_run_exp_id = experiment_id
@@ -336,14 +304,21 @@ def start_run(self, tags=None, experiment_id=None):
336304
new_run_exp_id = self.active_experiment.experiment_id
337305
else:
338306
new_run_exp_id = 0
339-
self.set_active_experiment(new_run_exp_id)
307+
try:
308+
self.set_active_experiment(new_run_exp_id)
309+
except MlflowException:
310+
raise MlflowException("There are no experiements available yet. Please create an experiment before starting a run")
340311

341312
if not tags:
342313
tags = {}
343314

344315
tags['mlflow.user'] = _get_user()
345316

346317
self.active_run = super(MLManager, self).create_run(new_run_exp_id, tags=tags)
318+
if run_name:
319+
manager.set_tag('mlflow.runName',run_name)
320+
print(f'Setting {run_name} to active run')
321+
347322

348323
def get_run(self, run_id):
349324
"""
@@ -396,9 +371,11 @@ def set_tag(self, *args, **kwargs):
396371
@check_active
397372
def set_tags(self, tags):
398373
"""
399-
Log a list of tags in order
374+
Log a list of tags in order or a dictionary of tags
400375
:param params: a list of tuples containing tags mapped to tag values
401376
"""
377+
if isinstance(tags,dict):
378+
tags = list(tags.items())
402379
for tag in tags:
403380
self.set_tag(*tag)
404381

@@ -480,7 +457,6 @@ def _is_spark_model(spark_object):
480457
is a model, it will return True, if it is a
481458
pipeline model is will return False.
482459
Otherwise, it will throw an exception
483-
484460
:param spark_object: (Model) Spark object to check
485461
:return: (bool) whether or not the object is a model
486462
:exception: (Exception) throws an error if it is not either
@@ -535,11 +511,9 @@ def log_pipeline_stages(self, pipeline):
535511
"""
536512
Log the human-friendly names of each stage in
537513
a Spark pipeline.
538-
539514
*Warning*: With a big pipeline, this could result in
540515
a lot of parameters in MLFlow. It is probably best
541516
to log them yourself, so you can ensure useful tracking
542-
543517
:param pipeline: the fitted/unfit pipeline object
544518
"""
545519

@@ -551,7 +525,6 @@ def log_pipeline_stages(self, pipeline):
551525
def _find_first_input_by_output(dictionary, value):
552526
"""
553527
Find the first input column for a given column
554-
555528
:param dictionary: dictionary to search
556529
:param value: column
557530
:return: None if not found, otherwise first column
@@ -566,7 +539,6 @@ def log_feature_transformations(self, unfit_pipeline):
566539
"""
567540
Log the preprocessing transformation sequence
568541
for every feature in the UNFITTED Spark pipeline
569-
570542
:param unfit_pipeline: UNFITTED spark pipeline!!
571543
"""
572544
transformations = defaultdict(lambda: [[], None]) # transformations, outputColumn
@@ -596,7 +568,6 @@ def start_timer(self, timer_name):
596568
Start a given timer with the specified
597569
timer name, which will be logged when the
598570
run is stopped
599-
600571
:param timer_name: the name to call the timer (will appear in MLFlow UI)
601572
"""
602573
self.timer_name = timer_name
@@ -626,10 +597,8 @@ def log_evaluator_metrics(self, splice_evaluator):
626597
"""
627598
Takes an Splice evaluator and logs
628599
all of the associated metrics with it
629-
630600
:param splice_evaluator: a Splice evaluator (from
631601
splicemachine.ml.utilities package in pysplice)
632-
633602
:return: retrieved metrics dict
634603
"""
635604
results = splice_evaluator.get_results('dict')
@@ -705,7 +674,6 @@ def download_artifact(self, name, local_path, run_id=None):
705674
Download the artifact at the given
706675
run id (active default) + name
707676
to the local path
708-
709677
:param name: (str) artifact name to load
710678
(with respect to the run)
711679
:param local_path: (str) local path to download the
@@ -745,7 +713,6 @@ def login_director(self, username, password):
745713
"""
746714
Login to MLmanager Director so we can
747715
submit jobs
748-
749716
:param username: (str) database username
750717
:param password: (str) database password
751718
"""
@@ -756,7 +723,6 @@ def login_director(self, username, password):
756723
def _initiate_job(self, payload, endpoint):
757724
"""
758725
Send a job to the initiation endpoint
759-
760726
:param payload: (dict) JSON payload for POST request
761727
:param endpoint: (str) REST endpoint to target
762728
:return: (str) Response text from request
@@ -868,7 +834,6 @@ def deploy_azure(self, endpoint_name, resource_group, workspace, run_id=None, re
868834
cpu_cores=0.1, allocated_ram=0.5, model_name=None):
869835
"""
870836
Deploy a given run to AzureML.
871-
872837
:param endpoint_name: (str) the name of the endpoint in AzureML when deployed to
873838
Azure Container Services. Must be unique.
874839
:param resource_group: (str) Azure Resource Group for model. Automatically created if

0 commit comments

Comments
 (0)