This repository was archived by the owner on Apr 15, 2022. It is now read-only.

Commit 0d35fa0

Author: Myles Novick
Merge pull request #125 from splicemachine/DBAAS-5007
DBAAS-5007: Allow label in get_training_set without a training view
2 parents 53ab0cb + 6a1d073 commit 0d35fa0

File tree

2 files changed: +30 -9 lines


splicemachine/features/feature_store.py

Lines changed: 29 additions & 9 deletions
@@ -192,7 +192,8 @@ def get_feature_description(self):
        raise NotImplementedError

    def get_training_set(self, features: Union[List[Feature], List[str]], current_values_only: bool = False,
-                         start_time: datetime = None, end_time: datetime = None, return_sql: bool = False) -> SparkDF:
+                         start_time: datetime = None, end_time: datetime = None, label: str = None, return_pk_cols: bool = False,
+                         return_ts_col: bool = False, return_sql: bool = False) -> SparkDF or str:
        """
        Gets a set of feature values across feature sets that is not time dependent (ie for non time series clustering).
        This feature dataset will be treated and tracked implicitly the same way a training_dataset is tracked from
@@ -222,13 +223,22 @@ def get_training_set(self, features: Union[List[Feature], List[str]], current_va
            were updated after this point in time won't be selected. If not specified (and current_values_only is False),
            Feature values up to the moment in time you call the function (now) will be retrieved. This parameter
            only takes effect if current_values_only is False.
+        :param label: An optional label to specify for the training set. If specified, the feature set of that feature
+            will be used as the "anchor" feature set, meaning all point in time joins will be made to the timestamps of
+            that feature set. This feature will also be recorded as a "label" feature for this particular training set
+            (but not others in the future, unless this label is again specified).
+        :param return_pk_cols: bool Whether or not the returned sql should include the primary key column(s)
+        :param return_ts_col: bool Whether or not the returned sql should include the timestamp column
        :return: Spark DF
        """
        features = [f if isinstance(f, str) else f.__dict__ for f in features]
-        r = make_request(self._FS_URL, Endpoints.TRAINING_SETS, RequestType.POST, self._basic_auth, { "current": current_values_only },
+        r = make_request(self._FS_URL, Endpoints.TRAINING_SETS, RequestType.POST, self._basic_auth,
+                         { "current": current_values_only, "label": label, "pks": return_pk_cols, "ts": return_ts_col },
                         { "features": features, "start_time": start_time, "end_time": end_time })
        create_time = r['metadata']['training_set_create_ts']
-
+        sql = r['sql']
+        tvw = TrainingView(**r['training_view'])
+        features = [Feature(**f) for f in r['features']]
        # Here we create a null training view and pass it into the training set. We do this because this special kind
        # of training set isn't standard. It's not based on a training view, on primary key columns, a label column,
        # or a timestamp column. This is simply a joined set of features from different feature sets.
@@ -243,12 +253,12 @@ def get_training_set(self, features: Union[List[Feature], List[str]], current_va
        # wants the most up to date values of each feature. So we set start_time to end_time (which is datetime.today)

        if self.mlflow_ctx and not return_sql:
-            self.link_training_set_to_mlflow(features, create_time, start_time, end_time)
-        return r if return_sql else self.splice_ctx.df(r)
+            self.link_training_set_to_mlflow(features, create_time, start_time, end_time, tvw)
+        return sql if return_sql else self.splice_ctx.df(sql)

    def get_training_set_from_view(self, training_view: str, features: Union[List[Feature], List[str]] = None,
                                   start_time: Optional[datetime] = None, end_time: Optional[datetime] = None,
-                                   return_sql: bool = False) -> SparkDF or str:
+                                   return_pk_cols: bool = False, return_ts_col: bool = False, return_sql: bool = False) -> SparkDF or str:
        """
        Returns the training set as a Spark Dataframe from a Training View. When a user calls this function (assuming they have registered
        the feature store with mlflow using :py:meth:`~mlflow.register_feature_store` )
@@ -285,13 +295,16 @@ def get_training_set_from_view(self, training_view: str, features: Union[List[Fe

            If end_time is None, query will get most recently available data

+        :param return_pk_cols: bool Whether or not the returned sql should include the primary key column(s)
+        :param return_ts_col: bool Whether or not the returned sql should include the timestamp column
        :param return_sql: (Optional[bool]) Return the SQL statement (str) instead of the Spark DF. Defaults False
        :return: Optional[SparkDF, str] The Spark dataframe of the training set or the SQL that is used to generate it (for debugging)
        """

        # # Generate the SQL needed to create the dataset
        features = [f if isinstance(f, str) else f.__dict__ for f in features] if features else None
-        r = make_request(self._FS_URL, Endpoints.TRAINING_SET_FROM_VIEW, RequestType.POST, self._basic_auth, { "view": training_view },
+        r = make_request(self._FS_URL, Endpoints.TRAINING_SET_FROM_VIEW, RequestType.POST, self._basic_auth,
+                         { "view": training_view, "pks": return_pk_cols, "ts": return_ts_col },
                         { "features": features, "start_time": start_time, "end_time": end_time })
        sql = r["sql"]
        tvw = TrainingView(**r["training_view"])
@@ -519,19 +532,26 @@ def _training_view_describe(self, tcx: TrainingView, feats: List[Feature]):
    def set_feature_description(self):
        raise NotImplementedError

-    def get_training_set_from_deployment(self, schema_name: str, table_name: str):
+    def get_training_set_from_deployment(self, schema_name: str, table_name: str, label: str = None,
+                                         return_pk_cols: bool = False, return_ts_col: bool = False):
        """
        Reads Feature Store metadata to rebuild original training data set used for the given deployed model.
        :param schema_name: model schema name
        :param table_name: model table name
+        :param label: An optional label to specify for the training set. If specified, the feature set of that feature
+            will be used as the "anchor" feature set, meaning all point in time joins will be made to the timestamps of
+            that feature set. This feature will also be recorded as a "label" feature for this particular training set
+            (but not others in the future, unless this label is again specified).
+        :param return_pk_cols: bool Whether or not the returned sql should include the primary key column(s)
+        :param return_ts_col: bool Whether or not the returned sql should include the timestamp column
        :return:
        """
        # database stores object names in upper case
        schema_name = schema_name.upper()
        table_name = table_name.upper()

        r = make_request(self._FS_URL, Endpoints.TRAINING_SET_FROM_DEPLOYMENT, RequestType.GET, self._basic_auth,
-                        { "schema": schema_name, "table": table_name })
+                        { "schema": schema_name, "table": table_name, "label": label, "pks": return_pk_cols, "ts": return_ts_col})

        metadata = r['metadata']
        sql = r['sql']
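
A minimal usage sketch of the updated signatures, assuming an already-initialized feature store client fs; the client setup, feature names, and the schema/table names are hypothetical and not taken from this commit:

from datetime import datetime

# Hypothetical feature names; the label can be any of the requested features.
feats = ['CUSTOMER_LTV', 'TOTAL_ORDERS', 'AVG_ORDER_VALUE']

# A label may now be supplied without a training view; point-in-time joins are
# anchored to the feature set of the label feature.
df = fs.get_training_set(
    feats,
    label='CUSTOMER_LTV',
    start_time=datetime(2021, 1, 1),
    end_time=datetime(2021, 6, 1),
    return_pk_cols=True,   # include the primary key column(s)
    return_ts_col=False,   # omit the timestamp column
)

# The generated SQL can be returned instead of a Spark DataFrame for debugging,
# and the same label/pk/ts flags apply when rebuilding a deployed model's training set
# (schema and table names below are placeholders).
sql = fs.get_training_set(feats, label='CUSTOMER_LTV', return_sql=True)
deployed = fs.get_training_set_from_deployment('RETAIL_FS', 'CHURN_MODEL', label='CUSTOMER_LTV',
                                               return_pk_cols=True)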

splicemachine/features/training_set.py

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ def _register_metadata(self, mlflow_ctx):
                mlflow_ctx.lp("splice.feature_store.training_set_end_time",str(self.end_time))
                mlflow_ctx.lp("splice.feature_store.training_set_create_time",str(self.create_time))
                mlflow_ctx.lp("splice.feature_store.training_set_num_features", len(self.features))
+               mlflow_ctx.lp("splice.feature_store.training_set_label", self.training_view.label_column)
                for i,f in enumerate(self.features):
                    mlflow_ctx.lp(f'splice.feature_store.training_set_feature_{i}',f.name)
            except:
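
With this change the training set's label column is also recorded as an mlflow parameter alongside the other training set metadata. A rough sketch of reading that parameter back from a finished run, assuming lp logs standard mlflow parameters as the surrounding lines suggest ("<run_id>" is a placeholder):

import mlflow

# "<run_id>" stands in for a run created while the feature store was registered with mlflow.
run = mlflow.get_run("<run_id>")
label = run.data.params.get("splice.feature_store.training_set_label")
print(label)  # the label column recorded for that training set, e.g. 'CUSTOMER_LTV'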

0 commit comments
