From 5cb4f54aaaa83815a905efe1b34b1c7bd887dd5a Mon Sep 17 00:00:00 2001 From: chenyangkang Date: Mon, 11 Nov 2024 10:39:26 -0600 Subject: [PATCH] update --- stemflow/model/AdaSTEM.py | 4 ++-- stemflow/model/static_func_AdaSTEM.py | 33 +++++++++++++++------------ stemflow/utils/plot_gif.py | 2 ++ 3 files changed, 22 insertions(+), 17 deletions(-) diff --git a/stemflow/model/AdaSTEM.py b/stemflow/model/AdaSTEM.py index 36819ff..f624f7f 100644 --- a/stemflow/model/AdaSTEM.py +++ b/stemflow/model/AdaSTEM.py @@ -275,7 +275,7 @@ def __init__( else: self.verbosity = 0 - def split(self, X_train: pd.core.frame.DataFrame, verbosity: Union[None, int] = None, ax=None, n_jobs: int = 1): + def split(self, X_train: pd.core.frame.DataFrame, verbosity: Union[None, int] = None, ax=None, n_jobs: Union[None, int] = None): """QuadTree indexing the input data Args: @@ -796,7 +796,7 @@ def predict_proba( X_test: pd.core.frame.DataFrame, verbosity: Union[int, None] = None, return_std: bool = False, - n_jobs: Union[None, int] = 1, + n_jobs: Union[None, int] = None, aggregation: str = "mean", return_by_separate_ensembles: bool = False, **base_model_prediction_param diff --git a/stemflow/model/static_func_AdaSTEM.py b/stemflow/model/static_func_AdaSTEM.py index 5007ef6..b97fa73 100644 --- a/stemflow/model/static_func_AdaSTEM.py +++ b/stemflow/model/static_func_AdaSTEM.py @@ -90,22 +90,25 @@ def train_one_stixel( sample_weights = class_weight.compute_sample_weight( class_weight="balanced", y=np.where(sub_y_train > 0, 1, 0) ) - - try: - trained_model.fit(sub_X_train[stixel_specific_x_names], sub_y_train, sample_weight=sample_weights) - - except Exception as e: - print(e) - # raise - return (None, [], "Base_model_fitting_error(non-regression, balanced weight)") + trained_model.fit(sub_X_train[stixel_specific_x_names], sub_y_train, sample_weight=sample_weights) + + # try: + # trained_model.fit(sub_X_train[stixel_specific_x_names], sub_y_train, sample_weight=sample_weights) + + # except Exception as e: + # print(e) + # # raise + # return (None, [], "Base_model_fitting_error(non-regression, balanced weight)") else: - try: - trained_model.fit(sub_X_train[stixel_specific_x_names], sub_y_train) - - except Exception as e: - print(e) - # raise - return (None, [], "Base_model_fitting_error(regression)") + trained_model.fit(sub_X_train[stixel_specific_x_names], sub_y_train) + + # try: + # trained_model.fit(sub_X_train[stixel_specific_x_names], sub_y_train) + + # except Exception as e: + # print(e) + # # raise + # return (None, [], "Base_model_fitting_error(regression)") return (trained_model, stixel_specific_x_names, "Success") diff --git a/stemflow/utils/plot_gif.py b/stemflow/utils/plot_gif.py index de91cde..ab4a678 100644 --- a/stemflow/utils/plot_gif.py +++ b/stemflow/utils/plot_gif.py @@ -122,6 +122,8 @@ def round_to_same_decimal_places(A, B): if log_scale else np.max(data[col].values) ) + + print(vmin, vmax) norm = Normalize(vmin=vmin, vmax=vmax) # Prepare colormap