@@ -192,7 +192,8 @@ def get_feature_description(self):
         raise NotImplementedError
 
     def get_training_set(self, features: Union[List[Feature], List[str]], current_values_only: bool = False,
-                         start_time: datetime = None, end_time: datetime = None, return_sql: bool = False) -> SparkDF:
+                         start_time: datetime = None, end_time: datetime = None, label: str = None, return_pk_cols: bool = False,
+                         return_ts_col: bool = False, return_sql: bool = False) -> Union[SparkDF, str]:
         """
         Gets a set of feature values across feature sets that is not time dependent (i.e. for non-time-series clustering).
         This feature dataset will be treated and tracked implicitly the same way a training_dataset is tracked from
@@ -222,13 +223,22 @@ def get_training_set(self, features: Union[List[Feature], List[str]], current_va
         were updated after this point in time won't be selected. If not specified (and current_values_only is False),
         Feature values up to the moment in time you call the function (now) will be retrieved. This parameter
         only takes effect if current_values_only is False.
+        :param label: An optional feature to use as the label for the training set. If specified, the feature set
+            containing that feature will be used as the "anchor" feature set, meaning all point-in-time joins will be
+            made against that feature set's timestamps. The feature will also be recorded as the "label" feature for
+            this particular training set (but not for future ones, unless this label is specified again).
+        :param return_pk_cols: bool Whether or not the returned SQL should include the primary key column(s)
+        :param return_ts_col: bool Whether or not the returned SQL should include the timestamp column
         :return: Spark DF, or the generated SQL string if return_sql is True
         """
         features = [f if isinstance(f, str) else f.__dict__ for f in features]
-        r = make_request(self._FS_URL, Endpoints.TRAINING_SETS, RequestType.POST, self._basic_auth, {"current": current_values_only},
+        r = make_request(self._FS_URL, Endpoints.TRAINING_SETS, RequestType.POST, self._basic_auth,
+                         {"current": current_values_only, "label": label, "pks": return_pk_cols, "ts": return_ts_col},
                          {"features": features, "start_time": start_time, "end_time": end_time})
         create_time = r['metadata']['training_set_create_ts']
-
+        sql = r['sql']
+        tvw = TrainingView(**r['training_view'])
+        features = [Feature(**f) for f in r['features']]
         # Here we create a null training view and pass it into the training set. We do this because this special kind
         # of training set isn't standard. It's not based on a training view, on primary key columns, a label column,
         # or a timestamp column. This is simply a joined set of features from different feature sets.
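
For reference, a minimal usage sketch of get_training_set with the new parameters; the FeatureStore handle fs and the feature names are hypothetical, not part of this diff:

from datetime import datetime

fs = FeatureStore()  # hypothetical client handle

# Point-in-time training set anchored on the feature set that owns CUSTOMER_LTV;
# CUSTOMER_LTV is also recorded as the label feature for this training set
df = fs.get_training_set(
    features=['CUSTOMER_LTV', 'CUSTOMER_AGE'],   # illustrative feature names
    start_time=datetime(2021, 1, 1),
    end_time=datetime(2021, 6, 1),
    label='CUSTOMER_LTV',
    return_pk_cols=True,      # keep the primary key column(s) in the result
)

# Pass return_sql=True to get the generated SQL (str) instead of a Spark DF
sql = fs.get_training_set(features=['CUSTOMER_LTV'], return_sql=True)
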
@@ -243,12 +253,12 @@ def get_training_set(self, features: Union[List[Feature], List[str]], current_va
         # wants the most up to date values of each feature. So we set start_time to end_time (which is datetime.today)
 
         if self.mlflow_ctx and not return_sql:
-            self.link_training_set_to_mlflow(features, create_time, start_time, end_time)
-        return r if return_sql else self.splice_ctx.df(r)
+            self.link_training_set_to_mlflow(features, create_time, start_time, end_time, tvw)
+        return sql if return_sql else self.splice_ctx.df(sql)
 
     def get_training_set_from_view(self, training_view: str, features: Union[List[Feature], List[str]] = None,
                                    start_time: Optional[datetime] = None, end_time: Optional[datetime] = None,
-                                   return_sql: bool = False) -> SparkDF or str:
+                                   return_pk_cols: bool = False, return_ts_col: bool = False, return_sql: bool = False) -> Union[SparkDF, str]:
         """
         Returns the training set as a Spark Dataframe from a Training View. When a user calls this function (assuming they have registered
         the feature store with mlflow using :py:meth:`~mlflow.register_feature_store`)
@@ -285,13 +295,16 @@ def get_training_set_from_view(self, training_view: str, features: Union[List[Fe
 
         If end_time is None, query will get most recently available data
 
+        :param return_pk_cols: bool Whether or not the returned SQL should include the primary key column(s)
+        :param return_ts_col: bool Whether or not the returned SQL should include the timestamp column
         :param return_sql: (Optional[bool]) Return the SQL statement (str) instead of the Spark DF. Defaults False
         :return: Union[SparkDF, str] The Spark dataframe of the training set or the SQL that is used to generate it (for debugging)
         """
 
         # Generate the SQL needed to create the dataset
         features = [f if isinstance(f, str) else f.__dict__ for f in features] if features else None
-        r = make_request(self._FS_URL, Endpoints.TRAINING_SET_FROM_VIEW, RequestType.POST, self._basic_auth, {"view": training_view},
+        r = make_request(self._FS_URL, Endpoints.TRAINING_SET_FROM_VIEW, RequestType.POST, self._basic_auth,
+                         {"view": training_view, "pks": return_pk_cols, "ts": return_ts_col},
                          {"features": features, "start_time": start_time, "end_time": end_time})
         sql = r["sql"]
         tvw = TrainingView(**r["training_view"])
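
Under the same assumptions as the sketch above, the view-based variant would look like this; the training view name is made up:

# Training set from a registered training view, keeping the primary key
# and timestamp columns in the returned dataframe
df = fs.get_training_set_from_view(
    training_view='CUSTOMER_CHURN_VIEW',   # hypothetical view name
    features=['CUSTOMER_LTV', 'CUSTOMER_AGE'],
    return_pk_cols=True,
    return_ts_col=True,
)
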
@@ -519,19 +532,26 @@ def _training_view_describe(self, tcx: TrainingView, feats: List[Feature]):
     def set_feature_description(self):
         raise NotImplementedError
 
-    def get_training_set_from_deployment(self, schema_name: str, table_name: str):
+    def get_training_set_from_deployment(self, schema_name: str, table_name: str, label: str = None,
+                                         return_pk_cols: bool = False, return_ts_col: bool = False):
         """
         Reads Feature Store metadata to rebuild the original training data set used for the given deployed model.
         :param schema_name: model schema name
         :param table_name: model table name
+        :param label: An optional feature to use as the label for the training set. If specified, the feature set
+            containing that feature will be used as the "anchor" feature set, meaning all point-in-time joins will be
+            made against that feature set's timestamps. The feature will also be recorded as the "label" feature for
+            this particular training set (but not for future ones, unless this label is specified again).
+        :param return_pk_cols: bool Whether or not the returned SQL should include the primary key column(s)
+        :param return_ts_col: bool Whether or not the returned SQL should include the timestamp column
         :return:
         """
         # database stores object names in upper case
         schema_name = schema_name.upper()
         table_name = table_name.upper()
 
         r = make_request(self._FS_URL, Endpoints.TRAINING_SET_FROM_DEPLOYMENT, RequestType.GET, self._basic_auth,
-                         {"schema": schema_name, "table": table_name})
+                         {"schema": schema_name, "table": table_name, "label": label, "pks": return_pk_cols, "ts": return_ts_col})
 
         metadata = r['metadata']
         sql = r['sql']
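
Finally, a sketch of rebuilding a deployed model's training set; the schema and table names are illustrative, and note the method upper-cases them before the request:

# Rebuild the training set behind a deployed model; label and the pk/ts
# flags are forwarded to the Feature Store API as query parameters
df = fs.get_training_set_from_deployment(
    schema_name='retail',        # upper-cased to RETAIL before the request
    table_name='churn_model',    # upper-cased to CHURN_MODEL
    return_pk_cols=True,
)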