This repository was archived by the owner on Apr 15, 2022. It is now read-only.

Commit 5ddcf9b

Authored by Ben Epstein
Dbaas 4951 (#98)
* cleanup of api * add context (primary) key to describe feature sets * optional verbose to print sql in create_training_context * added get_feature_dataset * comments * old code * i hate upppercase * commment * sql format * i still hate uppercase * null tx * sql format * sql format * docs * docs * docs * docs * docs * docs * docs * docs * docs * docs * docs * docs * docs * docs * docs * verbose * docs * docs * column ordering * feature param cleanup * training context features * removed clean_df * to_lower * docs * better logic * better logic * label column validation * refactor TrainingContext -> TrainingView, Feature Set Context Key -> Feature Set Join Key * missed one * exclude 2 more funcs * docs * as list * missed some more * hashable * pep * docs * docs * handleinvalid keep * feature_vector_sql * get-features_by_name requires names * exclude members * return Feature, docs fix * history for get_training_set without TrainingView * removed clean_df * missing collect * curretn values * 2 froms * add pk cols * join with itself * better line * merge master * remove include_insert * myles as codeowner
2 parents 7660e39 + b5b679a commit 5ddcf9b

6 files changed: +245 −144 lines changed


CODEOWNERS

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 #Global-Reviewers
 
 #* @splicemachine/splice-cloudops
-* @bklo94 @Ben-Epstein @edriggers @jhoule-splice @jramineni @njnygaard @patricksplice @splicemaahs @troysplice @abaveja313
+* @bklo94 @Ben-Epstein @edriggers @jramineni @myles-novick @splicemaahs @troysplice @abaveja313

splicemachine/features/constants.py

Lines changed: 2 additions & 1 deletion
@@ -96,7 +96,8 @@ class SQL:
         FROM {FEATURE_STORE_SCHEMA}.feature_set_key GROUP BY 1
     ) p
     ON fset.feature_set_id=p.feature_set_id
-    where fset.feature_set_id in (select feature_set_id from {FEATURE_STORE_SCHEMA}.feature where name in {{names}} )
+    WHERE fset.feature_set_id in (select feature_set_id from {FEATURE_STORE_SCHEMA}.feature where name in {{names}} )
+    ORDER BY schema_name, table_name
     """
 
     get_all_features = f"SELECT NAME FROM {FEATURE_STORE_SCHEMA}.feature WHERE Name='{{name}}'"

splicemachine/features/feature_set.py

Lines changed: 4 additions & 4 deletions
@@ -1,6 +1,5 @@
 from splicemachine.features import Feature
 from .constants import SQL, Columns
-from .utils import clean_df
 from splicemachine.spark import PySpliceContext
 from typing import List, Dict
 
@@ -34,8 +33,8 @@ def get_features(self) -> List[Feature]:
         """
         features = []
         if self.feature_set_id:
-            features_df = self.splice_ctx.df(SQL.get_features_in_feature_set.format(feature_set_id=self.feature_set_id))
-            features_df = clean_df(features_df, Columns.feature).collect()
+            features_df = self.splice_ctx.df(SQL.get_features_in_feature_set.format(feature_set_id=self.feature_set_id),
+                                             to_lower=True).collect()
             for f in features_df:
                 f = f.asDict()
                 features.append(Feature(**f))
@@ -134,7 +133,8 @@ def deploy(self, verbose=False):
 
     def __eq__(self, other):
         if isinstance(other, FeatureSet):
-            return self.table_name == other.table_name and self.schema_name == other.schema_name
+            return self.table_name.lower() == other.table_name.lower() and \
+                   self.schema_name.lower() == other.schema_name.lower()
         return False
 
     def __repr__(self):
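A small illustrative sketch of the new equality semantics (the schema/table values and the splice context variable are hypothetical; FeatureSet instances are normally built by the feature store from catalog rows):

fs_a = FeatureSet(splice_ctx=splice, schema_name='RETAIL', table_name='CUSTOMER_RFM')
fs_b = FeatureSet(splice_ctx=splice, schema_name='retail', table_name='customer_rfm')
assert fs_a == fs_b  # comparison is now case-insensitive on schema and table names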

splicemachine/features/feature_store.py

Lines changed: 82 additions & 135 deletions
@@ -16,7 +16,8 @@
 from splicemachine.spark import PySpliceContext
 from splicemachine.features import Feature, FeatureSet
 from .training_set import TrainingSet
-from .utils import dict_to_lower
+from .utils import (dict_to_lower, _generate_training_set_history_sql,
+                    _generate_training_set_sql, _create_temp_training_view)
 from .constants import SQL, FeatureType
 from .training_view import TrainingView
 
@@ -64,75 +65,6 @@ def get_feature_sets(self, feature_set_ids: List[int] = None, _filter: Dict[str,
             feature_sets.append(FeatureSet(splice_ctx=self.splice_ctx, **d))
         return feature_sets
 
-    def get_training_set(self, features: Union[List[Feature], List[str]]) -> SparkDF:
-        """
-        Gets a set of feature values across feature sets that is not time dependent (ie for non time series clustering).
-        This feature dataset will be treated and tracked implicitly the same way a training_dataset is tracked from
-        :py:meth:`features.FeatureStore.get_training_set` . The dataset's metadata and features used will be tracked in mlflow automatically (see
-        get_training_set for more details).
-
-        :param features: List of Features or strings of feature names
-
-        :NOTE:
-            .. code-block:: text
-
-                The Features Sets which the list of Features come from must have common join keys,
-                otherwise the function will fail. If there is no common join key, it is recommended to
-                create a Training View to specify the join conditions.
-
-        :return: Spark DF
-        """
-        features = self._process_features(features)
-
-        sql = SQL.get_feature_set_join_keys.format(names=tuple([f.name for f in features]))
-        fset_keys: pd.DataFrame = self.splice_ctx.df(sql).toPandas()
-        # Get max number of pk (join) columns from all feature sets
-        fset_keys['PK_COLUMNS_COUNT'] = fset_keys['PK_COLUMNS'].apply(lambda x: len(x.split('|')))
-        # Get "anchor" feature set. The one we will use to try to join to all others
-        ind = fset_keys['PK_COLUMNS_COUNT'].idxmax()
-        anchor_series = fset_keys.iloc[ind]
-        # Remove that from the list
-        fset_keys.drop(index=ind, inplace=True)
-        all_pk_cols = anchor_series.PK_COLUMNS.split('|')
-        # For each feature set, assert that all join keys exist in our "anchor" feature set
-        fset_keys['can_join'] = fset_keys['PK_COLUMNS'].map(lambda x: set(x.split('|')).issubset(all_pk_cols))
-        if not fset_keys['can_join'].all():
-            bad_feature_set_ids = [t.FEATURE_SET_ID for _, t in fset_keys[fset_keys['can_join'] != True].iterrows()]
-            bad_features = [f.name for f in features if f.feature_set_id in bad_feature_set_ids]
-            raise SpliceMachineException(f"The provided features do not have a common join key."
-                                         f"Remove features {bad_features} from your request")
-
-        # SELECT clause
-        sql = 'SELECT '
-
-        sql += ','.join([f'fset{feature.feature_set_id}.{feature.name}' for feature in features])
-
-        alias = f'fset{anchor_series.FEATURE_SET_ID}'  # We use this a lot for joins
-        sql += f'\nFROM {anchor_series.SCHEMA_NAME}.{anchor_series.TABLE_NAME} {alias} '
-
-        # JOIN clause
-        for _, fset in fset_keys.iterrows():
-            # Join Feature Set
-            sql += f'\nLEFT OUTER JOIN {fset.SCHEMA_NAME}.{fset.TABLE_NAME} fset{fset.FEATURE_SET_ID} \n\tON '
-            for ind, pkcol in enumerate(fset.PK_COLUMNS.split('|')):
-                if ind > 0: sql += ' AND '  # In case of multiple columns
-                sql += f'fset{fset.FEATURE_SET_ID}.{pkcol}={alias}.{pkcol}'
-
-        # Link this to mlflow for model deployment
-        # Here we create a null training view and pass it into the training set. We do this because this special kind
-        # of training set isn't standard. It's not based on a training view, on primary key columns, a label column,
-        # or a timestamp column . This is simply a joined set of features from different feature sets.
-        # But we still want to track this in mlflow as a user may build and deploy a model based on this. So we pass in
-        # a null training view that can be tracked with a "name" (although the name is None). This is a likely case
-        # for (non time based) clustering use cases.
-        null_tx = TrainingView(pk_columns=[], ts_column=None, label_column=None, view_sql=None, name=None,
-                               description=None)
-        ts = TrainingSet(training_view=null_tx, features=features)
-        if hasattr(self, 'mlflow_ctx'):
-            self.mlflow_ctx._active_training_set: TrainingSet = ts
-            ts._register_metadata(self.mlflow_ctx)
-        return self.splice_ctx.df(sql)
-
     def remove_training_view(self, override=False):
         """
         Note: This function is not yet implemented.
@@ -228,8 +160,8 @@ def _validate_feature_vector_keys(self, join_key_values, feature_sets) -> None:
         missing_keys = feature_set_key_columns - join_key_values.keys()
         assert not missing_keys, f"The following keys were not provided and must be: {missing_keys}"
 
-    def get_feature_vector(self, features: List[Union[str,Feature]],
-                           join_key_values: Dict[str,str], return_sql=False) -> Union[str, PandasDF]:
+    def get_feature_vector(self, features: List[Union[str, Feature]],
+                           join_key_values: Dict[str, str], return_sql=False) -> Union[str, PandasDF]:
         """
         Gets a feature vector given a list of Features and primary key values for their corresponding Feature Sets
@@ -246,9 +178,9 @@ def get_feature_vector(self, features: List[Union[str,Feature]],
         feature_sets = self.get_feature_sets([f.feature_set_id for f in feats])
         self._validate_feature_vector_keys(join_keys, feature_sets)
 
-
         feature_names = ','.join([f.name for f in feats])
-        fset_tables = ','.join([f'{fset.schema_name}.{fset.table_name} fset{fset.feature_set_id}' for fset in feature_sets])
+        fset_tables = ','.join(
+            [f'{fset.schema_name}.{fset.table_name} fset{fset.feature_set_id}' for fset in feature_sets])
         sql = "SELECT {feature_names} FROM {fset_tables} ".format(feature_names=feature_names, fset_tables=fset_tables)
 
         # For each Feature Set, for each primary key in the given feature set, get primary key value from the user provided dictionary
@@ -260,8 +192,7 @@ def get_feature_vector(self, features: List[Union[str,Feature]],
 
         return sql if return_sql else self.splice_ctx.df(sql).toPandas()
 
-    def get_feature_vector_sql_from_training_view(self, training_view: str, features: List[Feature],
-                                                  include_insert: Optional[bool] = True) -> str:
+    def get_feature_vector_sql_from_training_view(self, training_view: str, features: List[Feature]) -> str:
         """
         Returns the parameterized feature retrieval SQL used for online model serving.
@@ -274,25 +205,14 @@ def get_feature_vector_sql_from_training_view(self, training_view: str, features
             This function will error if the view SQL is missing a view key required to retrieve the\
             desired features
 
-        :param include_insert: (Optional[bool]) determines whether insert into model table is included in the SQL statement
         :return: (str) the parameterized feature vector SQL
         """
 
         # Get training view information (ctx primary key column(s), ctx primary key inference ts column, )
         vid = self.get_training_view_id(training_view)
         tctx = self.get_training_views(_filter={'view_id': vid})[0]
 
-        # optional INSERT prefix
-        if (include_insert):
-            sql = 'INSERT INTO {target_model_table} ('
-            for pkcol in tctx.pk_columns:  # Select primary key column(s)
-                sql += f'{pkcol}, '
-            for feature in features:
-                sql += f'{feature.name}, '  # Collect all features over time
-            sql = sql.rstrip(', ')
-            sql += ')\nSELECT '
-        else:
-            sql = 'SELECT '
+        sql = 'SELECT '
 
         # SELECT expressions
         for pkcol in tctx.pk_columns:  # Select primary key column(s)
@@ -351,6 +271,73 @@ def get_feature_description(self):
         # TODO
         raise NotImplementedError
 
+    def get_training_set(self, features: Union[List[Feature], List[str]], current_values_only: bool = False,
+                         start_time: datetime = None, end_time: datetime = None, return_sql: bool = False) -> SparkDF:
+        """
+        Gets a set of feature values across feature sets that is not time dependent (ie for non time series clustering).
+        This feature dataset will be treated and tracked implicitly the same way a training_dataset is tracked from
+        :py:meth:`features.FeatureStore.get_training_set` . The dataset's metadata and features used will be tracked in mlflow automatically (see
+        get_training_set for more details).
+
+        The way point-in-time correctness is guaranteed here is by choosing one of the Feature Sets as the "anchor" dataset.
+        This means that the points in time that the query is based off of will be the points in time in which the anchor
+        Feature Set recorded changes. The anchor Feature Set is the Feature Set that contains the superset of all primary key
+        columns across all Feature Sets from all Features provided. If more than 1 Feature Set has the superset of
+        all Feature Sets, the Feature Set with the most primary keys is selected. If more than 1 Feature Set has the same
+        maximum number of primary keys, the Feature Set is chosen by alphabetical order (schema_name, table_name).
+
+        :param features: List of Features or strings of feature names
+
+        :NOTE:
+            .. code-block:: text
+
+                The Features Sets which the list of Features come from must have common join keys,
+                otherwise the function will fail. If there is no common join key, it is recommended to
+                create a Training View to specify the join conditions.
+
+        :param current_values_only: If you only want the most recent values of the features, set this to true. Otherwise, all history will be returned. Default False
+        :param start_time: How far back in history you want Feature values. If not specified (and current_values_only is False), all history will be returned.
+            This parameter only takes effect if current_values_only is False.
+        :param end_time: The most recent values for each selected Feature. This will be the cutoff time, such that any Feature values that
+            were updated after this point in time won't be selected. If not specified (and current_values_only is False),
+            Feature values up to the moment in time you call the function (now) will be retrieved. This parameter
+            only takes effect if current_values_only is False.
+        :return: Spark DF
+        """
+        # Get List[Feature]
+        features = self._process_features(features)
+
+        # Get the Feature Sets
+        fsets = self.get_feature_sets(list({f.feature_set_id for f in features}))
+
+        if current_values_only:
+            sql = _generate_training_set_sql(features, fsets)
+        else:
+            temp_vw = _create_temp_training_view(features, fsets)
+            sql = _generate_training_set_history_sql(temp_vw, features, fsets, start_time=start_time, end_time=end_time)
+
+        # Here we create a null training view and pass it into the training set. We do this because this special kind
+        # of training set isn't standard. It's not based on a training view, on primary key columns, a label column,
+        # or a timestamp column . This is simply a joined set of features from different feature sets.
+        # But we still want to track this in mlflow as a user may build and deploy a model based on this. So we pass in
+        # a null training view that can be tracked with a "name" (although the name is None). This is a likely case
+        # for (non time based) clustering use cases.
+        null_tvw = TrainingView(pk_columns=[], ts_column=None, label_column=None, view_sql=None, name=None,
+                                description=None)
+        ts = TrainingSet(training_view=null_tvw, features=features, start_time=start_time, end_time=end_time)
+
+        # If the user isn't getting historical values, that means there isn't really a start_time, as the user simply
+        # wants the most up to date values of each feature. So we set start_time to end_time (which is datetime.today)
+        # For metadata purposes
+        if current_values_only:
+            ts.start_time = ts.end_time
+
+        if hasattr(self, 'mlflow_ctx'):
+            self.mlflow_ctx._active_training_set: TrainingSet = ts
+            ts._register_metadata(self.mlflow_ctx)
+        return sql if return_sql else self.splice_ctx.df(sql)
+
     def get_training_set_from_view(self, training_view: str, features: Union[List[Feature], List[str]] = None,
                                    start_time: Optional[datetime] = None, end_time: Optional[datetime] = None,
                                    return_sql: bool = False) -> SparkDF or str:
@@ -394,61 +381,21 @@ def get_training_set_from_view(self, training_view: str, features: Union[List[Fe
         :return: Optional[SparkDF, str] The Spark dataframe of the training set or the SQL that is used to generate it (for debugging)
         """
 
+        # Get features as list of Features
         features = self._process_features(features) if features else self.get_training_view_features(training_view)
-        # DB-9556 loss of column names on complex sql for NSDS
-        cols = []
-
-        # Get training view information (view primary key column(s), inference ts column, )
-        tctx = self.get_training_view(training_view)
-        # SELECT clause
-        sql = 'SELECT '
-        for pkcol in tctx.pk_columns:  # Select primary key column(s)
-            sql += f'\n\tctx.{pkcol},'
-            cols.append(pkcol)
-
-        sql += f'\n\tctx.{tctx.ts_column}, '  # Select timestamp column
-        cols.append(tctx.ts_column)
 
-        # TODO: ensure these features exist and fail gracefully if not
-        for feature in features:
-            sql += f'\n\tCOALESCE(fset{feature.feature_set_id}.{feature.name},fset{feature.feature_set_id}h.{feature.name}) {feature.name},'  # Collect all features over time
-            cols.append(feature.name)
-
-        sql = sql + f'\n\tctx.{tctx.label_column}' if tctx.label_column else sql.rstrip(
-            ',')  # Select the optional label col
-        if tctx.label_column: cols.append(tctx.label_column)
-
-        # FROM clause
-        sql += f'\nFROM ({tctx.view_sql}) ctx '
-
-        # JOIN clause
+        # Get List of necessary Feature Sets
         feature_set_ids = list({f.feature_set_id for f in features})  # Distinct set of IDs
         feature_sets = self.get_feature_sets(feature_set_ids)
-        for fset in feature_sets:
-            # Join Feature Set
-            sql += f'\nLEFT OUTER JOIN {fset.schema_name}.{fset.table_name} fset{fset.feature_set_id} \n\tON '
-            for pkcol in fset.pk_columns:
-                sql += f'fset{fset.feature_set_id}.{pkcol}=ctx.{pkcol} AND '
-            sql += f' ctx.{tctx.ts_column} >= fset{fset.feature_set_id}.LAST_UPDATE_TS '
 
-            # Join Feature Set History
-            sql += f'\nLEFT OUTER JOIN {fset.schema_name}.{fset.table_name}_history fset{fset.feature_set_id}h \n\tON '
-            for pkcol in fset.pk_columns:
-                sql += f' fset{fset.feature_set_id}h.{pkcol}=ctx.{pkcol} AND '
-            sql += f' ctx.{tctx.ts_column} >= fset{fset.feature_set_id}h.ASOF_TS AND ctx.{tctx.ts_column} < fset{fset.feature_set_id}h.UNTIL_TS'
-
-        # WHERE clause on optional start and end times
-        if start_time or end_time:
-            sql += '\nWHERE '
-            if start_time:
-                sql += f"\n\tctx.{tctx.ts_column} >= '{str(start_time)}' AND"
-            if end_time:
-                sql += f"\n\tctx.{tctx.ts_column} <= '{str(end_time)}'"
-            sql = sql.rstrip('AND')
+        # Get training view information (view primary key column(s), inference ts column, )
+        tvw = self.get_training_view(training_view)
+        # Generate the SQL needed to create the dataset
+        sql = _generate_training_set_history_sql(tvw, features, feature_sets, start_time=start_time, end_time=end_time)
 
         # Link this to mlflow for model deployment
         if hasattr(self, 'mlflow_ctx') and not return_sql:
-            ts = TrainingSet(training_view=tctx, features=features,
+            ts = TrainingSet(training_view=tvw, features=features,
                              start_time=start_time, end_time=end_time)
             self.mlflow_ctx._active_training_set: TrainingSet = ts
             ts._register_metadata(self.mlflow_ctx)
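A minimal usage sketch of the relocated get_training_set, assuming an existing PySpliceContext named splice is used to construct the FeatureStore and using hypothetical feature names:

from datetime import datetime
from splicemachine.features import FeatureStore

fs = FeatureStore(splice)

# Most recent value of each requested feature, joined across feature sets on their common keys
current_df = fs.get_training_set(['total_spend', 'num_visits'], current_values_only=True)

# Full point-in-time history between two dates; the anchor feature set supplies the timeline
history_df = fs.get_training_set(['total_spend', 'num_visits'],
                                 start_time=datetime(2020, 1, 1),
                                 end_time=datetime(2020, 6, 30))

# Inspect the generated SQL instead of executing it
print(fs.get_training_set(['total_spend', 'num_visits'], return_sql=True))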

splicemachine/features/training_set.py

Lines changed: 7 additions & 0 deletions
@@ -23,6 +23,13 @@ def __init__(self,
         self.end_time = end_time or datetime.today()
 
     def _register_metadata(self, mlflow_ctx):
+        """
+        Registers training set with mlflow if the user has registered the feature store in their mlflow session,
+        and has called either get_training_set or get_training_set_from_view before or during an mlflow run
+
+        :param mlflow_ctx: the mlflow context
+        :return: None
+        """
         if mlflow_ctx.active_run():
             print("There is an active mlflow run, your training set will be logged to that run.")
             mlflow_ctx.lp("splice.feature_store.training_set",self.training_view.name)
