 class FeatureStore:
     def __init__(self, splice_ctx: PySpliceContext) -> None:
         self.splice_ctx = splice_ctx
+        self.mlflow_ctx = None
         self.feature_sets = []  # Cache of newly created feature sets

     def register_splice_context(self, splice_ctx: PySpliceContext) -> None:

@@ -166,7 +167,7 @@ def get_feature_vector(self, features: List[Union[str, Feature]],
         Gets a feature vector given a list of Features and primary key values for their corresponding Feature Sets

         :param features: List of str Feature names or Features
-        :param join_key_values: (dict) join key vals to get the proper Feature values formatted as {join_key_column_name: join_key_value}
+        :param join_key_values: (dict) join key values to get the proper Feature values formatted as {join_key_column_name: join_key_value}
         :param return_sql: Whether to return the SQL needed to get the vector or the values themselves. Default False
         :return: Pandas Dataframe or str (SQL statement)
         """

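For orientation, a minimal usage sketch of the parameter shape documented above; the fs variable, feature names, and join key value are illustrative and not taken from this diff:

# Sketch only: assumes fs is a FeatureStore registered with a PySpliceContext.
vector = fs.get_feature_vector(
    features=['CUSTOMER_AGE', 'TOTAL_SPEND'],   # str Feature names or Feature objects
    join_key_values={'CUSTOMER_ID': 48},        # {join_key_column_name: join_key_value}
    return_sql=False                            # True would return the SQL statement instead
)
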
@@ -333,7 +334,7 @@ def get_training_set(self, features: Union[List[Feature], List[str]], current_va
         if current_values_only:
             ts.start_time = ts.end_time

-        if hasattr(self, 'mlflow_ctx'):
+        if self.mlflow_ctx and not return_sql:
             self.mlflow_ctx._active_training_set: TrainingSet = ts
             ts._register_metadata(self.mlflow_ctx)
         return sql if return_sql else self.splice_ctx.df(sql)

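A short sketch of what the tightened guard above means in practice; the call shapes below are assumptions based only on the parameters visible in this hunk (current_values_only, return_sql), not a verbatim API reference:

# With fs.mlflow_ctx set, asking for a DataFrame also records the TrainingSet
# as the active one for later model deployment.
df = fs.get_training_set(features=['CUSTOMER_AGE', 'TOTAL_SPEND'], current_values_only=True)

# Asking for SQL only now skips the mlflow registration entirely.
sql = fs.get_training_set(features=['CUSTOMER_AGE', 'TOTAL_SPEND'], return_sql=True)
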
@@ -394,7 +395,7 @@ def get_training_set_from_view(self, training_view: str, features: Union[List[Fe
         sql = _generate_training_set_history_sql(tvw, features, feature_sets, start_time=start_time, end_time=end_time)

         # Link this to mlflow for model deployment
-        if hasattr(self, 'mlflow_ctx') and not return_sql:
+        if self.mlflow_ctx and not return_sql:
             ts = TrainingSet(training_view=tvw, features=features,
                              start_time=start_time, end_time=end_time)
             self.mlflow_ctx._active_training_set: TrainingSet = ts

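The same guard now protects get_training_set_from_view; a hypothetical call shape follows, with the view name, feature names, and time window invented for illustration (start_time and end_time are assumed to be parameters, matching how they are passed through in this hunk):

from datetime import datetime

df = fs.get_training_set_from_view(
    training_view='CUSTOMER_CHURN_VIEW',
    features=['CUSTOMER_AGE', 'TOTAL_SPEND'],
    start_time=datetime(2020, 1, 1),
    end_time=datetime(2020, 6, 1)
)
# The resulting TrainingSet is linked to mlflow only when fs.mlflow_ctx is set
# and a DataFrame (rather than SQL) is requested.
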
@@ -735,12 +736,16 @@ def __log_mlflow_results(self, name, rounds, mlflow_results):
         :param name: MLflow run name
         :param rounds: Number of rounds of feature elimination that were run
         :param mlflow_results: The params / metrics to log
-        :return:
         """
-        with self.mlflow_ctx.start_run(run_name=name):
+        try:
+            if self.mlflow_ctx.active_run():
+                self.mlflow_ctx.start_run(run_name=name)
             for r in range(rounds):
                 with self.mlflow_ctx.start_run(run_name=f'Round {r}', nested=True):
                     self.mlflow_ctx.log_metrics(mlflow_results[r])
+        finally:
+            self.mlflow_ctx.end_run()
+

     def __prune_features_for_elimination(self, features) -> List[Feature]:
         """

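The try/finally rewrite of __log_mlflow_results above relies on mlflow's parent/nested run pattern. Below is a standalone sketch of that pattern against plain mlflow, whose call names (active_run, start_run, log_metrics, end_run) match the ones used on mlflow_ctx in the hunk; run names and metric values are made up, and the `not` in the guard reflects the usual mlflow idiom of opening a parent run only when none is already active:

import mlflow

mlflow_results = [{'CUSTOMER_AGE': 0.41, 'TOTAL_SPEND': 0.22},  # round 0 scores (illustrative)
                  {'CUSTOMER_AGE': 0.45}]                       # round 1 scores
try:
    if not mlflow.active_run():                      # open a parent run only if none is active
        mlflow.start_run(run_name='feature_elimination_demo')
    for r, metrics in enumerate(mlflow_results):
        with mlflow.start_run(run_name=f'Round {r}', nested=True):
            mlflow.log_metrics(metrics)              # one child run per elimination round
finally:
    mlflow.end_run()                                 # always close the parent run
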
@@ -814,7 +819,7 @@ def run_feature_elimination(self, df, features: List[Union[str, Feature]], label
                 round_metrics[row['name']] = row['score']
             mlflow_results.append(round_metrics)

-            if log_mlflow and hasattr(self, 'mlflow_ctx'):
+            if log_mlflow and self.mlflow_ctx:
                 run_name = mlflow_run_name or f'feature_elimination_{label}'
                 self.__log_mlflow_results(run_name, rnd, mlflow_results)
