Skip to content
This repository was archived by the owner on Apr 15, 2022. It is now read-only.

Commit 1cbfb84

Browse files
author
Ben Epstein
authored
Merge branch 'master' into DBAAS-5131
2 parents 9f90fec + 0d35fa0 commit 1cbfb84

File tree

2 files changed

+31
-11
lines changed

2 files changed

+31
-11
lines changed

splicemachine/features/feature_store.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,8 @@ def get_feature_description(self):
190190
raise NotImplementedError
191191

192192
def get_training_set(self, features: Union[List[Feature], List[str]], current_values_only: bool = False,
193-
start_time: datetime = None, end_time: datetime = None, return_sql: bool = False) -> SparkDF:
193+
start_time: datetime = None, end_time: datetime = None, label: str = None, return_pk_cols: bool = False,
194+
return_ts_col: bool = False, return_sql: bool = False) -> SparkDF or str:
194195
"""
195196
Gets a set of feature values across feature sets that is not time dependent (ie for non time series clustering).
196197
This feature dataset will be treated and tracked implicitly the same way a training_dataset is tracked from
@@ -220,14 +221,22 @@ def get_training_set(self, features: Union[List[Feature], List[str]], current_va
220221
were updated after this point in time won't be selected. If not specified (and current_values_only is False),
221222
Feature values up to the moment in time you call the function (now) will be retrieved. This parameter
222223
only takes effect if current_values_only is False.
224+
:param label: An optional label to specify for the training set. If specified, the feature set of that feature
225+
will be used as the "anchor" feature set, meaning all point in time joins will be made to the timestamps of
226+
that feature set. This feature will also be recorded as a "label" feature for this particular training set
227+
(but not others in the future, unless this label is again specified).
228+
:param return_pk_cols: bool Whether or not the returned sql should include the primary key column(s)
229+
:param return_ts_col: bool Whether or not the returned sql should include the timestamp column
223230
:return: Spark DF, or the SQL string used to generate it if return_sql is True
224231
"""
225232
features = [f if isinstance(f, str) else f.__dict__ for f in features]
226-
r = make_request(self._FS_URL, Endpoints.TRAINING_SETS, RequestType.POST, self._basic_auth,
227-
{ "current": current_values_only },
228-
{ "features": features, "start_time": start_time, "end_time": end_time })
233+
r = make_request(self._FS_URL, Endpoints.TRAINING_SETS, RequestType.POST, self._basic_auth,
234+
{ "current": current_values_only, "label": label, "pks": return_pk_cols, "ts": return_ts_col },
235+
{ "features": features, "start_time": start_time, "end_time": end_time })
229236
create_time = r['metadata']['training_set_create_ts']
230-
237+
sql = r['sql']
238+
tvw = TrainingView(**r['training_view'])
239+
features = [Feature(**f) for f in r['features']]
231240
# Here we create a null training view and pass it into the training set. We do this because this special kind
232241
# of training set isn't standard. It's not based on a training view, on primary key columns, a label column,
233242
# or a timestamp column . This is simply a joined set of features from different feature sets.
@@ -242,12 +251,12 @@ def get_training_set(self, features: Union[List[Feature], List[str]], current_va
242251
# wants the most up to date values of each feature. So we set start_time to end_time (which is datetime.today)
243252

244253
if self.mlflow_ctx and not return_sql:
245-
self.link_training_set_to_mlflow(features, create_time, start_time, end_time)
246-
return r if return_sql else self.splice_ctx.df(r)
254+
self.link_training_set_to_mlflow(features, create_time, start_time, end_time, tvw)
255+
return sql if return_sql else self.splice_ctx.df(sql)
247256

248257
def get_training_set_from_view(self, training_view: str, features: Union[List[Feature], List[str]] = None,
249258
start_time: Optional[datetime] = None, end_time: Optional[datetime] = None,
250-
return_sql: bool = False) -> SparkDF or str:
259+
return_pk_cols: bool = False, return_ts_col: bool = False, return_sql: bool = False) -> SparkDF or str:
251260
"""
252261
Returns the training set as a Spark Dataframe from a Training View. When a user calls this function (assuming they have registered
253262
the feature store with mlflow using :py:meth:`~mlflow.register_feature_store` )
@@ -284,13 +293,16 @@ def get_training_set_from_view(self, training_view: str, features: Union[List[Fe
284293
285294
If end_time is None, query will get most recently available data
286295
296+
:param return_pk_cols: bool Whether or not the returned sql should include the primary key column(s)
297+
:param return_ts_col: bool Whether or not the returned sql should include the timestamp column
287298
:param return_sql: (Optional[bool]) Return the SQL statement (str) instead of the Spark DF. Defaults False
288299
:return: Optional[SparkDF, str] The Spark dataframe of the training set or the SQL that is used to generate it (for debugging)
289300
"""
290301

291302
# # Generate the SQL needed to create the dataset
292303
features = [f if isinstance(f, str) else f.__dict__ for f in features] if features else None
293-
r = make_request(self._FS_URL, Endpoints.TRAINING_SET_FROM_VIEW, RequestType.POST, self._basic_auth, { "view": training_view },
304+
r = make_request(self._FS_URL, Endpoints.TRAINING_SET_FROM_VIEW, RequestType.POST, self._basic_auth,
305+
{ "view": training_view, "pks": return_pk_cols, "ts": return_ts_col },
294306
{ "features": features, "start_time": start_time, "end_time": end_time })
295307
sql = r["sql"]
296308
tvw = TrainingView(**r["training_view"])
@@ -518,19 +530,26 @@ def _training_view_describe(self, tcx: TrainingView, feats: List[Feature]):
518530
def set_feature_description(self):
519531
raise NotImplementedError
520532

521-
def get_training_set_from_deployment(self, schema_name: str, table_name: str):
533+
def get_training_set_from_deployment(self, schema_name: str, table_name: str, label: str = None,
534+
return_pk_cols: bool = False, return_ts_col: bool = False):
522535
"""
523536
Reads Feature Store metadata to rebuild the original training data set used for the given deployed model.
524537
:param schema_name: model schema name
525538
:param table_name: model table name
539+
:param label: An optional label to specify for the training set. If specified, the feature set of that feature
540+
will be used as the "anchor" feature set, meaning all point in time joins will be made to the timestamps of
541+
that feature set. This feature will also be recorded as a "label" feature for this particular training set
542+
(but not others in the future, unless this label is again specified).
543+
:param return_pk_cols: bool Whether or not the returned sql should include the primary key column(s)
544+
:param return_ts_col: bool Whether or not the returned sql should include the timestamp column
526545
:return:
527546
"""
528547
# database stores object names in upper case
529548
schema_name = schema_name.upper()
530549
table_name = table_name.upper()
531550

532551
r = make_request(self._FS_URL, Endpoints.TRAINING_SET_FROM_DEPLOYMENT, RequestType.GET, self._basic_auth,
533-
{ "schema": schema_name, "table": table_name })
552+
{ "schema": schema_name, "table": table_name, "label": label, "pks": return_pk_cols, "ts": return_ts_col})
534553

535554
metadata = r['metadata']
536555
sql = r['sql']

splicemachine/features/training_set.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def _register_metadata(self, mlflow_ctx):
4141
mlflow_ctx.lp("splice.feature_store.training_set_end_time",str(self.end_time))
4242
mlflow_ctx.lp("splice.feature_store.training_set_create_time",str(self.create_time))
4343
mlflow_ctx.lp("splice.feature_store.training_set_num_features", len(self.features))
44+
mlflow_ctx.lp("splice.feature_store.training_set_label", self.training_view.label_column)
4445
for i,f in enumerate(self.features):
4546
mlflow_ctx.lp(f'splice.feature_store.training_set_feature_{i}',f.name)
4647
except:

0 commit comments

Comments
 (0)