Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyangkang committed Nov 11, 2024
1 parent 2bfa3e3 commit 5cb4f54
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 17 deletions.
4 changes: 2 additions & 2 deletions stemflow/model/AdaSTEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def __init__(
else:
self.verbosity = 0

def split(self, X_train: pd.core.frame.DataFrame, verbosity: Union[None, int] = None, ax=None, n_jobs: int = 1):
def split(self, X_train: pd.core.frame.DataFrame, verbosity: Union[None, int] = None, ax=None, n_jobs: Union[None, int] = None):
"""QuadTree indexing the input data
Args:
Expand Down Expand Up @@ -796,7 +796,7 @@ def predict_proba(
X_test: pd.core.frame.DataFrame,
verbosity: Union[int, None] = None,
return_std: bool = False,
n_jobs: Union[None, int] = 1,
n_jobs: Union[None, int] = None,
aggregation: str = "mean",
return_by_separate_ensembles: bool = False,
**base_model_prediction_param
Expand Down
33 changes: 18 additions & 15 deletions stemflow/model/static_func_AdaSTEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,22 +90,25 @@ def train_one_stixel(
sample_weights = class_weight.compute_sample_weight(
class_weight="balanced", y=np.where(sub_y_train > 0, 1, 0)
)

try:
trained_model.fit(sub_X_train[stixel_specific_x_names], sub_y_train, sample_weight=sample_weights)

except Exception as e:
print(e)
# raise
return (None, [], "Base_model_fitting_error(non-regression, balanced weight)")
trained_model.fit(sub_X_train[stixel_specific_x_names], sub_y_train, sample_weight=sample_weights)

# try:
# trained_model.fit(sub_X_train[stixel_specific_x_names], sub_y_train, sample_weight=sample_weights)

# except Exception as e:
# print(e)
# # raise
# return (None, [], "Base_model_fitting_error(non-regression, balanced weight)")
else:
try:
trained_model.fit(sub_X_train[stixel_specific_x_names], sub_y_train)

except Exception as e:
print(e)
# raise
return (None, [], "Base_model_fitting_error(regression)")
trained_model.fit(sub_X_train[stixel_specific_x_names], sub_y_train)

# try:
# trained_model.fit(sub_X_train[stixel_specific_x_names], sub_y_train)

# except Exception as e:
# print(e)
# # raise
# return (None, [], "Base_model_fitting_error(regression)")

return (trained_model, stixel_specific_x_names, "Success")

Expand Down
2 changes: 2 additions & 0 deletions stemflow/utils/plot_gif.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ def round_to_same_decimal_places(A, B):
if log_scale
else np.max(data[col].values)
)

print(vmin, vmax)
norm = Normalize(vmin=vmin, vmax=vmax)

# Prepare colormap
Expand Down

0 comments on commit 5cb4f54

Please sign in to comment.