1010import pandas as pd
1111from sklearn .feature_selection import chi2 , f_classif , f_regression , mutual_info_classif , mutual_info_regression
1212from statsmodels .stats .outliers_influence import variance_inflation_factor
13+ from scipy .special import rel_entr
1314
1415from feature .base import _BaseSupervisedSelector , _BaseDispatcher
1516from feature .utils import get_selector , Num , get_task_string
17+ from feature .kl_divergence import _KL_Divergence
1618
1719
1820class _Statistical (_BaseSupervisedSelector , _BaseDispatcher ):
@@ -33,15 +35,17 @@ def __init__(self, seed: int, num_features: Num, method: str):
3335 self .imp = None
3436
3537 # Implementor factory
36- self .factory = {"regression_anova" : f_regression ,
37- "regression_chi_square" : None ,
38- "regression_mutual_info" : partial (mutual_info_regression , random_state = self .seed ),
39- # "regression_maximal_info": MINE(), # dropped
40- "classification_anova" : f_classif ,
38+ self .factory = {"classification_anova" : f_classif ,
4139 "classification_chi_square" : chi2 ,
4240 "classification_mutual_info" : partial (mutual_info_classif , random_state = self .seed ),
4341 # "classification_maximal_info": MINE(), # dropped
44- "unsupervised_variance_inflation" : variance_inflation_factor }
42+ "kl_divergence" : _KL_Divergence (num_features = self .num_features , seed = self .seed ),
43+ "regression_anova" : f_regression ,
44+ "regression_chi_square" : None ,
45+ "regression_mutual_info" : partial (mutual_info_regression , random_state = self .seed ),
46+ # "regression_maximal_info": MINE(), # dropped
47+ "unsupervised_variance_inflation" : variance_inflation_factor ,
48+ }
4549
4650 def get_model_args (self , selection_method ) -> Tuple :
4751
@@ -54,14 +58,18 @@ def dispatch_model(self, labels: pd.Series, *args):
5458 method = args [0 ]
5559
5660 # Get statistical scoring function
57- if method == "variance_inflation" :
61+ if method == "kl_divergence" :
62+ score_func = self .factory .get (method )
63+ elif method == "variance_inflation" :
5864 score_func = self .factory .get ("unsupervised_" + method )
5965 else :
6066 score_func = self .factory .get (get_task_string (labels ) + method )
6167
6268 # Check scoring compatibility with task
6369 if score_func is None :
6470 raise TypeError (method + " cannot be used for task: " + get_task_string (labels ))
71+ elif method == "kl_divergence" :
72+ self .imp = score_func
6573 elif method == "variance_inflation" : # or isinstance(score_func, MINE) (dropped)
6674 self .imp = score_func
6775 else :
@@ -82,6 +90,7 @@ def fit(self, data: pd.DataFrame, labels: pd.Series) -> NoReturn:
8290 if self .method == "variance_inflation" :
8391 # VIF is unsupervised: each feature is regressed on all the other features in data
8492 self .abs_scores = np .array ([variance_inflation_factor (data .values , i ) for i in range (data .shape [1 ])])
93+
8594 else :
8695 # sklearn selector model
8796 self .imp .fit (X = data , y = labels )
0 commit comments