Description
Is there any way to use the tool with categorical/text features?
I'm using a fitted CatBoost model (which handles the feature transformations itself). Since the method works on the SHAP values of the features, it should not need the raw features to be transformed. Nevertheless, I get an error when trying:
model = CatBoostClassifier(cat_features=categorical_cols, text_features=text_cols)
model.fit(X_train, y_train)
This works fine:
explainer = shap.Explainer(model)
shap_values = explainer.shap_values(Pool(X_train, y_train, cat_features=categorical_cols,text_features=text_cols))
shap.summary_plot(shap_values, X_train)
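For comparison, computing SHAP values on the validation data also works as long as the frame is wrapped in a Pool that declares the categorical and text columns; a minimal sketch reusing the variables above (X_test/y_test as in the shap-select call below):
# Also works: the Pool carries the cat/text feature metadata,
# so CatBoost never has to cast the text columns to float
val_pool = Pool(X_test, y_test, cat_features=categorical_cols, text_features=text_cols)
shap_values_val = explainer.shap_values(val_pool)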
Running shap-select results in an error:
df_val = pd.concat([X_test, y_test], axis=1)
from shap_select import shap_select
selected_features_df = shap_select(tree_model=clf_model, validation_df=df_val, target="y", task="binary", threshold=0.5)
Output error:
TypeError                                 Traceback (most recent call last)
File _catboost.pyx:2547, in _catboost.get_float_feature()

File _catboost.pyx:1226, in _catboost._FloatOrNan()

File _catboost.pyx:1021, in _catboost._FloatOrNanFromString()

TypeError: Cannot convert 'b'Secondary education teaching professionals'' to float

During handling of the above exception, another exception occurred:

CatBoostError                             Traceback (most recent call last)
Cell In[57], line 4
      1 from shap_select import shap_select
      2 # Here model is any model supported by the shap library, fitted on a different (train) dataset
      3 # Task can be regression, binary, or multiclass
----> 4 selected_features_df = shap_select(tree_model=clf_model, validation_df=df_val, target="y", task="binary", threshold=0.5)

File /opt/anaconda3/envs/MedRag/lib/python3.11/site-packages/shap_select/select.py:316, in shap_select(tree_model, validation_df, target, feature_names, task, threshold, return_extended_data, alpha)
    312     shap_features = create_shap_features(
    313         tree_model, validation_df[feature_names], unique_classes
    314     )
    315 else:
--> 316     shap_features = create_shap_features(tree_model, validation_df[feature_names])
    318 # Compute statistical significance of each feature, recursively ablating
    319 significance_df = iterative_shap_feature_reduction(
    320     shap_features, target, task, alpha
    321 )

File /opt/anaconda3/envs/MedRag/lib/python3.11/site-packages/shap_select/select.py:24, in create_shap_features(tree_model, validation_df, classes)
      9 def create_shap_features(
     10     tree_model: Any, validation_df: pd.DataFrame, classes: List | None = None
     11 ) -> pd.DataFrame | Dict[Any, pd.DataFrame]:
     12     """
     13     Generates SHAP (SHapley Additive exPlanations) values for a given tree-based model on a validation dataset.
    (...)
     22     corresponds to the SHAP values of a feature, and the rows match the index of the validation_df.
     23     """
---> 24     explainer = shap.Explainer(tree_model, model_output="raw")(validation_df)
     25     shap_values = explainer.values
     27     if len(shap_values.shape) == 2:

File /opt/anaconda3/envs/MedRag/lib/python3.11/site-packages/shap/explainers/_tree.py:262, in TreeExplainer.__call__(self, X, y, interactions, check_additivity)
    259 feature_names = getattr(self, "data_feature_names", None)
    261 if not interactions:
--> 262     v = self.shap_values(X, y=y, from_call=True, check_additivity=check_additivity, approximate=self.approximate)
    263     if isinstance(v, list):
    264         v = np.stack(v, axis=-1)  # put outputs at the end

File /opt/anaconda3/envs/MedRag/lib/python3.11/site-packages/shap/explainers/_tree.py:464, in TreeExplainer.shap_values(self, X, y, tree_limit, approximate, check_additivity, from_call)
    462 import catboost
    463 if type(X) != catboost.Pool:
--> 464     X = catboost.Pool(X, cat_features=self.model.cat_feature_indices)
    465 phi = self.model.original_model.get_feature_importance(data=X, fstr_type='ShapValues')
    467 # note we pull off the last column and keep it as our expected_value

File /opt/anaconda3/envs/MedRag/lib/python3.11/site-packages/catboost/core.py:855, in Pool.__init__(self, data, label, cat_features, text_features, embedding_features, embedding_features_data, column_description, pairs, graph, delimiter, has_header, ignore_csv_quoting, weight, group_id, group_weight, subgroup_id, pairs_weight, baseline, timestamp, feature_names, feature_tags, thread_count, log_cout, log_cerr, data_can_be_none)
    849     if isinstance(feature_names, PATH_TYPES):
    850         raise CatBoostError(
    851             "feature_names must be None or have non-string type when the pool is created from "
    852             "python objects."
    853         )
--> 855     self._init(data, label, cat_features, text_features, embedding_features, embedding_features_data, pairs, graph, weight,
    856                group_id, group_weight, subgroup_id, pairs_weight, baseline, timestamp, feature_names, feature_tags, thread_count)
    857 elif not data_can_be_none:
    858     raise CatBoostError("'data' parameter can't be None")

File /opt/anaconda3/envs/MedRag/lib/python3.11/site-packages/catboost/core.py:1491, in Pool._init(self, data, label, cat_features, text_features, embedding_features, embedding_features_data, pairs, graph, weight, group_id, group_weight, subgroup_id, pairs_weight, baseline, timestamp, feature_names, feature_tags, thread_count)
   1489 if feature_tags is not None:
   1490     feature_tags = self._check_transform_tags(feature_tags, feature_names)
-> 1491 self._init_pool(data, label, cat_features, text_features, embedding_features, embedding_features_data, pairs, graph, weight,
   1492                 group_id, group_weight, subgroup_id, pairs_weight, baseline, timestamp, feature_names, feature_tags, thread_count)

File _catboost.pyx:4339, in _catboost._PoolBase._init_pool()

File _catboost.pyx:4391, in _catboost._PoolBase._init_pool()

File _catboost.pyx:4200, in _catboost._PoolBase._init_features_order_layout_pool()

File _catboost.pyx:3127, in _catboost._set_features_order_data_pd_data_frame()

File _catboost.pyx:2591, in _catboost.create_num_factor_data()

File _catboost.pyx:2549, in _catboost.get_float_feature()

CatBoostError: Bad value for num_feature[non_default_doc_idx=0,feature_idx=2]="Secondary education teaching professionals": Cannot convert 'b'Secondary education teaching professionals'' to float
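If I'm reading the traceback right, the failure comes from the frame at shap/explainers/_tree.py:464: when shap_select passes the raw DataFrame, shap's TreeExplainer re-wraps it in a Pool with cat_features only and no text_features, so the text column is parsed as a numeric feature. A minimal sketch of the difference, reusing the variables from above (illustrative only, not a proposed fix):
import catboost

# Works - this is effectively what the summary_plot example above does:
# the Pool declares both the categorical and the text columns, so
# get_feature_importance can compute SHAP values without any float conversion
ok_pool = catboost.Pool(X_test, cat_features=categorical_cols, text_features=text_cols)
shap_ok = model.get_feature_importance(data=ok_pool, fstr_type="ShapValues")

# Fails with the same CatBoostError - this mirrors the wrapping done inside
# TreeExplainer.shap_values (_tree.py:464), which passes cat_features only,
# so the text column is treated as numeric
bad_pool = catboost.Pool(X_test, cat_features=categorical_cols)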