@@ -190,7 +190,8 @@ def get_feature_description(self):
         raise NotImplementedError
 
     def get_training_set(self, features: Union[List[Feature], List[str]], current_values_only: bool = False,
-                         start_time: datetime = None, end_time: datetime = None, return_sql: bool = False) -> SparkDF:
+                         start_time: datetime = None, end_time: datetime = None, label: str = None, return_pk_cols: bool = False,
+                         return_ts_col: bool = False, return_sql: bool = False) -> SparkDF or str:
         """
         Gets a set of feature values across feature sets that is not time dependent (ie for non time series clustering).
         This feature dataset will be treated and tracked implicitly the same way a training_dataset is tracked from
@@ -220,14 +221,22 @@ def get_training_set(self, features: Union[List[Feature], List[str]], current_va
         were updated after this point in time won't be selected. If not specified (and current_values_only is False),
         Feature values up to the moment in time you call the function (now) will be retrieved. This parameter
         only takes effect if current_values_only is False.
+        :param label: An optional label to specify for the training set. If specified, the feature set of that feature
+            will be used as the "anchor" feature set, meaning all point-in-time joins will be made to the timestamps of
+            that feature set. This feature will also be recorded as a "label" feature for this particular training set
+            (but not others in the future, unless this label is again specified).
+        :param return_pk_cols: bool Whether or not the returned SQL should include the primary key column(s)
+        :param return_ts_col: bool Whether or not the returned SQL should include the timestamp column
         :return: Spark DF
         """
         features = [f if isinstance(f, str) else f.__dict__ for f in features]
-        r = make_request(self._FS_URL, Endpoints.TRAINING_SETS, RequestType.POST, self._basic_auth,
-                         {"current": current_values_only},
-                         {"features": features, "start_time": start_time, "end_time": end_time})
+        r = make_request(self._FS_URL, Endpoints.TRAINING_SETS, RequestType.POST, self._basic_auth,
+                         {"current": current_values_only, "label": label, "pks": return_pk_cols, "ts": return_ts_col},
+                         {"features": features, "start_time": start_time, "end_time": end_time})
         create_time = r['metadata']['training_set_create_ts']
-
+        sql = r['sql']
+        tvw = TrainingView(**r['training_view'])
+        features = [Feature(**f) for f in r['features']]
         # Here we create a null training view and pass it into the training set. We do this because this special kind
         # of training set isn't standard. It's not based on a training view, on primary key columns, a label column,
         # or a timestamp column. This is simply a joined set of features from different feature sets.
@@ -242,12 +251,12 @@ def get_training_set(self, features: Union[List[Feature], List[str]], current_va
         # wants the most up to date values of each feature. So we set start_time to end_time (which is datetime.today)
 
         if self.mlflow_ctx and not return_sql:
-            self.link_training_set_to_mlflow(features, create_time, start_time, end_time)
-        return r if return_sql else self.splice_ctx.df(r)
+            self.link_training_set_to_mlflow(features, create_time, start_time, end_time, tvw)
+        return sql if return_sql else self.splice_ctx.df(sql)
 
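
As a quick sketch of how the new arguments compose, a call might look like the following. This is hypothetical usage, not part of the diff: it assumes `fs` is an initialized `FeatureStore` client with a registered Spark context, and the feature and label names are made up.

from datetime import datetime

# Point-in-time-correct training set, anchored on the feature set that owns the
# (hypothetical) CUSTOMER_CHURN_LABEL feature, keeping PK and timestamp columns.
df = fs.get_training_set(
    features=['CUSTOMER_LTV', 'CUSTOMER_TENURE', 'CUSTOMER_CHURN_LABEL'],
    label='CUSTOMER_CHURN_LABEL',  # anchor feature set; recorded as this set's label
    start_time=datetime(2021, 1, 1),
    end_time=datetime(2021, 6, 30),
    return_pk_cols=True,           # include the primary key column(s)
    return_ts_col=True,            # include the timestamp column
)

# Or return the generated SQL for debugging instead of a Spark DataFrame.
sql = fs.get_training_set(features=['CUSTOMER_LTV'], return_sql=True)
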
     def get_training_set_from_view(self, training_view: str, features: Union[List[Feature], List[str]] = None,
                                    start_time: Optional[datetime] = None, end_time: Optional[datetime] = None,
-                                   return_sql: bool = False) -> SparkDF or str:
+                                   return_pk_cols: bool = False, return_ts_col: bool = False, return_sql: bool = False) -> SparkDF or str:
         """
         Returns the training set as a Spark Dataframe from a Training View. When a user calls this function (assuming they have registered
         the feature store with mlflow using :py:meth:`~mlflow.register_feature_store`)
@@ -284,13 +293,16 @@ def get_training_set_from_view(self, training_view: str, features: Union[List[Fe
 
         If end_time is None, query will get most recently available data
 
+        :param return_pk_cols: bool Whether or not the returned SQL should include the primary key column(s)
+        :param return_ts_col: bool Whether or not the returned SQL should include the timestamp column
         :param return_sql: (Optional[bool]) Return the SQL statement (str) instead of the Spark DF. Defaults False
         :return: Optional[SparkDF, str] The Spark dataframe of the training set or the SQL that is used to generate it (for debugging)
         """
 
         # # Generate the SQL needed to create the dataset
         features = [f if isinstance(f, str) else f.__dict__ for f in features] if features else None
-        r = make_request(self._FS_URL, Endpoints.TRAINING_SET_FROM_VIEW, RequestType.POST, self._basic_auth, {"view": training_view},
+        r = make_request(self._FS_URL, Endpoints.TRAINING_SET_FROM_VIEW, RequestType.POST, self._basic_auth,
+                         {"view": training_view, "pks": return_pk_cols, "ts": return_ts_col},
                          {"features": features, "start_time": start_time, "end_time": end_time})
         sql = r["sql"]
         tvw = TrainingView(**r["training_view"])
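
A companion sketch for the view-based path, under the same assumptions as the example above (the training view name below is a placeholder):

# Same hypothetical `fs` client as above; 'churn_view' is a made-up training view.
df = fs.get_training_set_from_view(
    training_view='churn_view',
    features=['CUSTOMER_LTV', 'CUSTOMER_TENURE'],  # optional subset of the view's features
    return_pk_cols=True,   # keep the view's primary key column(s)
    return_ts_col=False,   # drop the timestamp column from the result
)
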
@@ -518,19 +530,26 @@ def _training_view_describe(self, tcx: TrainingView, feats: List[Feature]):
     def set_feature_description(self):
         raise NotImplementedError
 
-    def get_training_set_from_deployment(self, schema_name: str, table_name: str):
+    def get_training_set_from_deployment(self, schema_name: str, table_name: str, label: str = None,
+                                         return_pk_cols: bool = False, return_ts_col: bool = False):
         """
         Reads Feature Store metadata to rebuild the original training data set used for the given deployed model.
         :param schema_name: model schema name
         :param table_name: model table name
+        :param label: An optional label to specify for the training set. If specified, the feature set of that feature
+            will be used as the "anchor" feature set, meaning all point-in-time joins will be made to the timestamps of
+            that feature set. This feature will also be recorded as a "label" feature for this particular training set
+            (but not others in the future, unless this label is again specified).
+        :param return_pk_cols: bool Whether or not the returned SQL should include the primary key column(s)
+        :param return_ts_col: bool Whether or not the returned SQL should include the timestamp column
         :return:
         """
         # database stores object names in upper case
         schema_name = schema_name.upper()
         table_name = table_name.upper()
 
         r = make_request(self._FS_URL, Endpoints.TRAINING_SET_FROM_DEPLOYMENT, RequestType.GET, self._basic_auth,
-                         {"schema": schema_name, "table": table_name})
+                         {"schema": schema_name, "table": table_name, "label": label, "pks": return_pk_cols, "ts": return_ts_col})
 
         metadata = r['metadata']
         sql = r['sql']
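
Finally, a hypothetical sketch of rebuilding a deployed model's training set (the schema and table names are placeholders; note the client upper-cases them before issuing the request):

# Same hypothetical `fs` client; schema/table names are made up.
df = fs.get_training_set_from_deployment(
    schema_name='retention',         # upper-cased to RETENTION by the client
    table_name='churn_model_v2',
    label='CUSTOMER_CHURN_LABEL',    # optional anchor/label feature
    return_pk_cols=True,
    return_ts_col=True,
)
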