Skip to content
This repository was archived by the owner on Apr 15, 2022. It is now read-only.

Commit ea7074f

Browse files
author
Myles Novick
authored
Merge branch 'master' into DBAAS-5102
2 parents 674a0a7 + ec2838e commit ea7074f

File tree

4 files changed

+38
-11
lines changed

4 files changed

+38
-11
lines changed

splicemachine/features/feature_store.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,8 @@ def get_training_set_from_view(self, training_view: str, features: Union[List[Fe
293293
r = make_request(self._FS_URL, Endpoints.TRAINING_SET_FROM_VIEW, RequestType.POST, self._basic_auth, { "view": training_view },
294294
{ "features": features, "start_time": start_time, "end_time": end_time })
295295
sql = r["sql"]
296-
tvw = r["training_view"]
296+
tvw = TrainingView(**r["training_view"])
297+
features = [Feature(**f) for f in r["features"]]
297298

298299
# Link this to mlflow for model deployment
299300
if self.mlflow_ctx and not return_sql:
@@ -528,16 +529,19 @@ def get_training_set_from_deployment(self, schema_name: str, table_name: str):
528529

529530
r = make_request(self._FS_URL, Endpoints.TRAINING_SET_FROM_DEPLOYMENT, RequestType.GET, self._basic_auth,
530531
{ "schema": schema_name, "table": table_name })
531-
metadata = r['metadata']
532532

533+
metadata = r['metadata']
533534
sql = r['sql']
534-
features = r['features']
535+
535536
tv_name = metadata['name']
536537
start_time = metadata['training_set_start_ts']
537538
end_time = metadata['training_set_end_ts']
538539

540+
tv = TrainingView(**r['training_view']) if 'training_view' in r else None
541+
features = [Feature(**f) for f in r['features']]
542+
539543
if self.mlflow_ctx:
540-
self.link_training_set_to_mlflow(features, start_time, end_time, tv_name)
544+
self.link_training_set_to_mlflow(features, start_time, end_time, tv)
541545
return self.splice_ctx.df(sql)
542546

543547
def remove_feature(self, name: str):
@@ -551,14 +555,23 @@ def remove_feature(self, name: str):
551555
"""
552556
make_request(self._FS_URL, Endpoints.FEATURES, RequestType.DELETE, self._basic_auth, { "name": name })
553557

554-
def get_training_set_features(self, training_set: str = None):
558+
def get_deployments(self, schema_name: str = None, table_name: str = None, training_set: str = None):
555559
"""
556-
Returns a list of all features from an available Training Set, as well as details about that Training Set
560+
Returns a list of all (or specified) available deployments
557561
:param schema_name: model schema name
558562
:param table_name: model table name
559563
:param training_set: training set name
560564
:return: List[Deployment] the list of Deployments as dicts
561565
"""
566+
return make_request(self._FS_URL, Endpoints.DEPLOYMENTS, RequestType.GET, self._basic_auth,
567+
{ 'schema': schema_name, 'table': table_name, 'name': training_set })
568+
569+
def get_training_set_features(self, training_set: str = None):
570+
"""
571+
Returns a list of all features from an available Training Set, as well as details about that Training Set
572+
:param training_set: training set name
573+
:return: TrainingSet as dict
574+
"""
562575
r = make_request(self._FS_URL, Endpoints.TRAINING_SET_FEATURES, RequestType.GET, self._basic_auth,
563576
{ 'name': training_set })
564577
r['features'] = [Feature(**f) for f in r['features']]
@@ -798,6 +811,7 @@ def link_training_set_to_mlflow(self, features: Union[List[Feature], List[str]],
798811

799812
self.mlflow_ctx._active_training_set: TrainingSet = ts
800813
ts._register_metadata(self.mlflow_ctx)
814+
801815

802816
def set_feature_store_url(self, url: str):
803817
self._FS_URL = url

splicemachine/features/training_set.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ def _register_metadata(self, mlflow_ctx):
4848
"Training Set was logged to the current active run. If you call "
4949
"fs.get_training_set or fs.get_training_set_from_view before starting an "
5050
"mlflow run, all following runs will assume that Training Set to be the "
51-
"active Training Set, and will log the Training Set as metadata. For more "
52-
"information, refer to the documentation. If you'd like to use a new "
53-
"Training Set, end the current run, call one of the mentioned functions, "
54-
"and start your new run.") from None
51+
"active Training Set (until the next call to either of those functions), "
52+
"and will log the Training Set as metadata. For more information, "
53+
"refer to the documentation. If you'd like to use a new Training Set, "
54+
"end the current run, call one of the mentioned functions, and start "
55+
"your new run. Or, call mlflow.remove_active_training_set()") from None

splicemachine/features/utils/http_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class Endpoints:
2828
"""
2929
Enum for Feature Store Endpoints
3030
"""
31+
DEPLOYMENTS: str = "deployments"
3132
FEATURES: str = "features"
3233
FEATURE_SETS: str = "feature-sets"
3334
FEATURE_SET_DESCRIPTIONS: str = "feature-set-descriptions"

splicemachine/mlflow_support/mlflow_support.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,16 @@ def _start_run(run_id=None, tags=None, experiment_id=None, run_name=None, nested
438438

439439
return SpliceActiveRun(active_run)
440440

441+
@_mlflow_patch('remove_active_training_set')
442+
def _remove_active_training_set():
443+
"""
444+
Removes the active training set from mlflow. This function deletes mlflows active training set (retrieved from
445+
the feature store), which will in turn stop the automated logging of features to the active mlflow run. To recreate
446+
an active training set, call fs.get_training_set or fs.get_training_set_from_view in the Feature Store.
447+
"""
448+
if hasattr(mlflow,'_active_training_set'):
449+
del mlflow._active_training_set
450+
441451

442452
@_mlflow_patch('log_pipeline_stages')
443453
def _log_pipeline_stages(pipeline):
@@ -1002,7 +1012,8 @@ def apply_patches():
10021012
targets = [_register_feature_store, _register_splice_context, _lp, _lm, _timer, _log_artifact, _log_feature_transformations,
10031013
_log_model_params, _log_pipeline_stages, _log_model, _load_model, _download_artifact,
10041014
_start_run, _current_run_id, _current_exp_id, _deploy_aws, _deploy_azure, _deploy_db, _login_director,
1005-
_get_run_ids_by_name, _get_deployed_models, _deploy_kubernetes, _fetch_logs, _watch_job, _end_run, _set_mlflow_uri]
1015+
_get_run_ids_by_name, _get_deployed_models, _deploy_kubernetes, _fetch_logs, _watch_job, _end_run,
1016+
_set_mlflow_uri, _remove_active_training_set]
10061017

10071018
for target in targets:
10081019
gorilla.apply(gorilla.Patch(mlflow, target.__name__.lstrip('_'), target, settings=_GORILLA_SETTINGS))

0 commit comments

Comments
 (0)