
Commit 90f874f

Author: Ben Epstein
Merge pull request #99 from splicemachine/DBAAS-4971 (Dbaas 4971)

2 parents: b063a9b + bfd2786

3 files changed: +36 -16 lines

splicemachine/features/feature_store.py

Lines changed: 11 additions & 6 deletions
@@ -25,6 +25,7 @@
 class FeatureStore:
     def __init__(self, splice_ctx: PySpliceContext) -> None:
         self.splice_ctx = splice_ctx
+        self.mlflow_ctx = None
         self.feature_sets = []  # Cache of newly created feature sets
 
     def register_splice_context(self, splice_ctx: PySpliceContext) -> None:
@@ -166,7 +167,7 @@ def get_feature_vector(self, features: List[Union[str, Feature]],
         Gets a feature vector given a list of Features and primary key values for their corresponding Feature Sets
 
         :param features: List of str Feature names or Features
-        :param join_key_values: (dict) join key vals to get the proper Feature values formatted as {join_key_column_name: join_key_value}
+        :param join_key_values: (dict) join key values to get the proper Feature values formatted as {join_key_column_name: join_key_value}
         :param return_sql: Whether to return the SQL needed to get the vector or the values themselves. Default False
         :return: Pandas Dataframe or str (SQL statement)
         """
@@ -333,7 +334,7 @@ def get_training_set(self, features: Union[List[Feature], List[str]], current_va
         if current_values_only:
             ts.start_time = ts.end_time
 
-        if hasattr(self, 'mlflow_ctx'):
+        if self.mlflow_ctx and not return_sql:
             self.mlflow_ctx._active_training_set: TrainingSet = ts
             ts._register_metadata(self.mlflow_ctx)
         return sql if return_sql else self.splice_ctx.df(sql)
@@ -394,7 +395,7 @@ def get_training_set_from_view(self, training_view: str, features: Union[List[Fe
         sql = _generate_training_set_history_sql(tvw, features, feature_sets, start_time=start_time, end_time=end_time)
 
         # Link this to mlflow for model deployment
-        if hasattr(self, 'mlflow_ctx') and not return_sql:
+        if self.mlflow_ctx and not return_sql:
             ts = TrainingSet(training_view=tvw, features=features,
                              start_time=start_time, end_time=end_time)
             self.mlflow_ctx._active_training_set: TrainingSet = ts
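
Note on the guard changes above: because __init__ now always assigns self.mlflow_ctx = None, the old hasattr(self, 'mlflow_ctx') checks would pass unconditionally, so the guards switch to truthiness tests. A minimal sketch (illustrative only, not part of the commit) of the difference:

    class Demo:
        def __init__(self):
            self.mlflow_ctx = None  # attribute now always exists

    d = Demo()
    print(hasattr(d, 'mlflow_ctx'))  # True, even with no mlflow context registered
    print(bool(d.mlflow_ctx))        # False, which is what the new guards rely on
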
@@ -735,12 +736,16 @@ def __log_mlflow_results(self, name, rounds, mlflow_results):
         :param name: MLflow run name
         :param rounds: Number of rounds of feature elimination that were run
         :param mlflow_results: The params / metrics to log
-        :return:
         """
-        with self.mlflow_ctx.start_run(run_name=name):
+        try:
+            if self.mlflow_ctx.active_run():
+                self.mlflow_ctx.start_run(run_name=name)
             for r in range(rounds):
                 with self.mlflow_ctx.start_run(run_name=f'Round {r}', nested=True):
                     self.mlflow_ctx.log_metrics(mlflow_results[r])
+        finally:
+            self.mlflow_ctx.end_run()
+
 
     def __prune_features_for_elimination(self, features) -> List[Feature]:
         """
@@ -814,7 +819,7 @@ def run_feature_elimination(self, df, features: List[Union[str, Feature]], label
                 round_metrics[row['name']] = row['score']
             mlflow_results.append(round_metrics)
 
-        if log_mlflow and hasattr(self, 'mlflow_ctx'):
+        if log_mlflow and self.mlflow_ctx:
            run_name = mlflow_run_name or f'feature_elimination_{label}'
            self.__log_mlflow_results(run_name, rnd, mlflow_results)

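The rewritten __log_mlflow_results guarantees the parent run is closed even if logging a round's metrics fails partway through. Below is a hedged sketch of that try/finally plus nested-run pattern written against plain mlflow (the commit routes the same calls through self.mlflow_ctx; the function name and sample metrics here are made up):

    import mlflow

    def log_rounds(name, rounds, results):
        try:
            mlflow.start_run(run_name=name)          # parent run for the experiment
            for r in range(rounds):
                # one nested child run per feature-elimination round
                with mlflow.start_run(run_name=f'Round {r}', nested=True):
                    mlflow.log_metrics(results[r])   # e.g. {'feature_a': 0.91}
        finally:
            mlflow.end_run()                         # parent run always ends
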
splicemachine/features/training_set.py

Lines changed: 21 additions & 7 deletions
@@ -2,6 +2,7 @@
 from .feature import Feature
 from typing import List, Optional
 from datetime import datetime
+from splicemachine import SpliceMachineException
 
 class TrainingSet:
     """
@@ -19,7 +20,7 @@ def __init__(self,
                  ):
         self.training_view = training_view
         self.features = features
-        self.start_time = start_time or datetime.min
+        self.start_time = start_time or datetime(year=1900,month=1,day=1) # Saw problems with spark handling datetime.min
         self.end_time = end_time or datetime.today()
 
     def _register_metadata(self, mlflow_ctx):
@@ -32,9 +33,22 @@ def _register_metadata(self, mlflow_ctx):
         """
         if mlflow_ctx.active_run():
             print("There is an active mlflow run, your training set will be logged to that run.")
-            mlflow_ctx.lp("splice.feature_store.training_set",self.training_view.name)
-            mlflow_ctx.lp("splice.feature_store.training_set_start_time",str(self.start_time))
-            mlflow_ctx.lp("splice.feature_store.training_set_end_time",str(self.end_time))
-            mlflow_ctx.lp("splice.feature_store.training_set_num_features", len(self.features))
-            for i,f in enumerate(self.features):
-                mlflow_ctx.lp(f'splice.feature_store.training_set_feature_{i}',f.name)
+            try:
+                mlflow_ctx.lp("splice.feature_store.training_set",self.training_view.name)
+                mlflow_ctx.lp("splice.feature_store.training_set_start_time",str(self.start_time))
+                mlflow_ctx.lp("splice.feature_store.training_set_end_time",str(self.end_time))
+                mlflow_ctx.lp("splice.feature_store.training_set_num_features", len(self.features))
+                for i,f in enumerate(self.features):
+                    mlflow_ctx.lp(f'splice.feature_store.training_set_feature_{i}',f.name)
+            except:
+                raise SpliceMachineException("It looks like your active run already has a Training Set logged to it. "
+                                             "You cannot get a new active Training Set during an active run if you "
+                                             "already have an active Training Set. If you've called fs.get_training_set "
+                                             "or fs.get_training_set_from_view before starting this run, then that "
+                                             "Training Set was logged to the current active run. If you call "
+                                             "fs.get_training_set or fs.get_training_set_from_view before starting an "
+                                             "mlflow run, all following runs will assume that Training Set to be the "
+                                             "active Training Set, and will log the Training Set as metadata. For more "
+                                             "information, refer to the documentation. If you'd like to use a new "
+                                             "Training Set, end the current run, call one of the mentioned functions, "
+                                             "and start your new run.") from None

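Two behavior changes worth calling out above: the start_time fallback moves from datetime.min (year 1, which the inline comment says Spark mishandled) to a 1900-01-01 sentinel, and any failure while logging the training-set params (plausibly mlflow rejecting a param already set on the active run) now surfaces as a SpliceMachineException with guidance. A quick illustration of the default change:

    from datetime import datetime

    start_time = None  # caller passed no start time
    old_default = start_time or datetime.min                         # 0001-01-01 00:00:00
    new_default = start_time or datetime(year=1900, month=1, day=1)  # 1900-01-01 00:00:00
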
splicemachine/spark/context.py

Lines changed: 4 additions & 3 deletions
@@ -124,16 +124,17 @@ def replaceDataframeSchema(self, dataframe, schema_table_name):
 
     def fileToTable(self, file_path, schema_table_name, primary_keys=None, drop_table=False, **pandas_args):
         """
-        Load a file from the local filesystem and create a new table (or recreate an existing table), and load the data
-        from the file into the new table
+        Load a file from the local filesystem or from a remote location and create a new table
+        (or recreate an existing table), and load the data from the file into the new table. Any file_path that can be
+        read by pandas should work here.
 
         :param file_path: The local file to load
         :param schema_table_name: The schema.table name
         :param primary_keys: List[str] of primary keys for the table. Default None
         :param drop_table: Whether or not to drop the table. If this is False and the table already exists, the
                            function will fail. Default False
         :param pandas_args: Extra parameters to be passed into the pd.read_csv function. Any parameters accepted
-                           in pd.read_csv will work here
+                            in pd.read_csv will work here
         :return: None
         """
         import pandas as pd

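Since fileToTable hands file_path and **pandas_args straight to pd.read_csv, a remote URL should work as well as a local path, which is what the widened docstring describes. A hypothetical call (the context variable, URL, and table name are invented for illustration):

    # splice is assumed to be an already-instantiated PySpliceContext
    splice.fileToTable(
        'https://example.com/data/iris.csv',  # any path pd.read_csv accepts
        'SPLICE.IRIS',
        primary_keys=['ID'],
        drop_table=True,
        sep=',',                              # extra kwargs forwarded to pd.read_csv
    )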