|
25 | 25 | class FeatureStore: |
26 | 26 | def __init__(self, splice_ctx: PySpliceContext) -> None: |
27 | 27 | self.splice_ctx = splice_ctx |
| 28 | + self.mlflow_ctx = None |
28 | 29 | self.feature_sets = [] # Cache of newly created feature sets |
29 | 30 |
|
30 | 31 | def register_splice_context(self, splice_ctx: PySpliceContext) -> None: |
@@ -166,7 +167,7 @@ def get_feature_vector(self, features: List[Union[str, Feature]], |
166 | 167 | Gets a feature vector given a list of Features and primary key values for their corresponding Feature Sets |
167 | 168 |
|
168 | 169 | :param features: List of str Feature names or Features |
169 | | - :param join_key_values: (dict) join key vals to get the proper Feature values formatted as {join_key_column_name: join_key_value} |
| 170 | + :param join_key_values: (dict) join key values to get the proper Feature values formatted as {join_key_column_name: join_key_value} |
170 | 171 | :param return_sql: Whether to return the SQL needed to get the vector or the values themselves. Default False |
171 | 172 | :return: Pandas Dataframe or str (SQL statement) |
172 | 173 | """ |
@@ -735,12 +736,16 @@ def __log_mlflow_results(self, name, rounds, mlflow_results): |
735 | 736 | :param name: MLflow run name |
736 | 737 | :param rounds: Number of rounds of feature elimination that were run |
737 | 738 | :param mlflow_results: The params / metrics to log |
738 | | - :return: |
739 | 739 | """ |
740 | | - with self.mlflow_ctx.start_run(run_name=name): |
| 740 | + try: |
| 741 | + if self.mlflow_ctx.active_run(): |
| 742 | + self.mlflow_ctx.start_run(run_name=name) |
741 | 743 | for r in range(rounds): |
742 | 744 | with self.mlflow_ctx.start_run(run_name=f'Round {r}', nested=True): |
743 | 745 | self.mlflow_ctx.log_metrics(mlflow_results[r]) |
| 746 | + finally: |
| 747 | + self.mlflow_ctx.end_run() |
| 748 | + |
744 | 749 |
|
745 | 750 | def __prune_features_for_elimination(self, features) -> List[Feature]: |
746 | 751 | """ |
|
0 commit comments