
Commit 731cb58

Merge pull request #22 from fidelity/feature/kl_divergence
KL Divergence Based Feature Selection
2 parents 7e07abc + b704d69

File tree

11 files changed: +200 -17 lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 2 deletions

@@ -16,15 +16,14 @@ jobs:
     strategy:
       matrix:
         python-version: ["3.8", "3.9", "3.10"]
-        os: [ubuntu-latest, macos-latest, windows-latest]
+        os: [ubuntu-latest, windows-latest]
       fail-fast: false
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v2
         with:
           python-version: ${{ matrix.python-version }}
-
       - name: Check
         shell: bash
         run: |

CHANGELOG.txt

Lines changed: 6 additions & 0 deletions

@@ -2,6 +2,12 @@
 CHANGELOG
 =========

+-------------------------------------------------------------------------------
+August 7, 2025 1.2.0
+-------------------------------------------------------------------------------
+
+- Added KL Divergence based feature selection for binary labels. Thanks to @zohairshafi for contributing this method.
+
 -------------------------------------------------------------------------------
 April, 24, 2023 1.1.2
 -------------------------------------------------------------------------------
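For readers skimming the changelog, the new method plugs into the existing Statistical selection API. A minimal sketch (assuming the `Selective`/`SelectionMethod` imports shown in the repository README, not part of this changelog entry):

    # Minimal sketch: enable the new method through the existing Statistical API.
    # Note: "kl_divergence" only supports binary (two-class) labels.
    from feature.selector import Selective, SelectionMethod

    selector = Selective(SelectionMethod.Statistical(num_features=5, method="kl_divergence"))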

README.md

Lines changed: 3 additions & 2 deletions

@@ -48,7 +48,7 @@ print("Scores:", list(selector.get_absolute_scores()))
 |:------:|:------:|
 | [Variance per Feature](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.VarianceThreshold.html) | `threshold` |
 | [Correlation pairwise Features](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.corr.html) | [Pearson Correlation Coefficient](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient) <br> [Kendall Rank Correlation Coefficient](https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient) <br> [Spearman's Rank Correlation Coefficient](https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient) <br> |
-| [Statistical Analysis](https://scikit-learn.org/stable/modules/feature_selection.html#univariate-feature-selection) | [ANOVA F-test Classification](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.f_classif.html) <br> [F-value Regression](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.f_regression.html) <br> [Chi-Square](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.chi2.html) <br> [Mutual Information Classification](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.mutual_info_classif.html) <br> [Variance Inflation Factor](https://www.statsmodels.org/stable/generated/statsmodels.stats.outliers_influence.variance_inflation_factor.html) |
+| [Statistical Analysis](https://scikit-learn.org/stable/modules/feature_selection.html#univariate-feature-selection) | [ANOVA F-test Classification](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.f_classif.html) <br> [F-value Regression](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.f_regression.html) <br> [Chi-Square](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.chi2.html) <br> [KL Divergence](https://en.wikipedia.org/wiki/Kullback–Leibler_divergence) <br> [Mutual Information Classification](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.mutual_info_classif.html) <br> [Variance Inflation Factor](https://www.statsmodels.org/stable/generated/statsmodels.stats.outliers_influence.variance_inflation_factor.html) |
 | [Linear Methods](https://en.wikipedia.org/wiki/Linear_regression) | [Linear Regression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html?highlight=linear%20regression#sklearn.linear_model.LinearRegression) <br> [Logistic Regression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html?highlight=logistic%20regression#sklearn.linear_model.LogisticRegression) <br> [Lasso Regularization](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html#sklearn.linear_model.Lasso) <br> [Ridge Regularization](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Ridge.html#sklearn.linear_model.Ridge) <br> |
 | [Tree-based Methods](https://scikit-learn.org/stable/modules/tree.html) | [Decision Tree](https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier) <br> [Random Forest](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html?highlight=random%20forest#sklearn.ensemble.RandomForestClassifier) <br> [Extra Trees Classifier](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.ExtraTreesClassifier.html) <br> [XGBoost](https://xgboost.readthedocs.io/en/latest/) <br> [LightGBM](https://lightgbm.readthedocs.io/en/latest/) <br> [AdaBoost](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.AdaBoostClassifier.html) <br> [CatBoost](https://github.com/catboost)<br> [Gradient Boosting Tree](http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingClassifier.html) <br> |
 | [Text-based Methods](https://link.springer.com/chapter/10.1007/978-3-030-78230-6_27) | `featurization_method` = [TextWiser](https://github.com/fidelity/textwiser) <br> `optimization_method = ["exact", "greedy", "kmeans", "random"]` <br> `cost_metric = ["unicost", "diverse"]` |

@@ -81,6 +81,7 @@ selectors = {
     # Statistical methods
     "stat_anova": SelectionMethod.Statistical(num_features, method="anova"),
     "stat_chi_square": SelectionMethod.Statistical(num_features, method="chi_square"),
+    "stat_kl_divergence": SelectionMethod.Statistical(num_features, method="kl_divergence"),
     "stat_mutual_info": SelectionMethod.Statistical(num_features, method="mutual_info"),

     # Linear methods

@@ -168,7 +169,7 @@ plot_importance(df)

 ## Installation

-Selective requires **Python 3.7+** and can be installed from PyPI using ``pip install selective``.
+Selective requires **Python 3.8+** and can be installed from PyPI using ``pip install selective``.

 ## Source
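For context, a hedged end-to-end sketch in the spirit of the README quick-start; the synthetic data, column names, and feature count below are illustrative and not part of the commit:

    import numpy as np
    import pandas as pd

    from feature.selector import Selective, SelectionMethod

    # Illustrative binary-label data (kl_divergence requires exactly two label values)
    rng = np.random.default_rng(0)
    data = pd.DataFrame(rng.normal(size=(100, 4)), columns=["a", "b", "c", "d"])
    label = pd.Series(rng.integers(0, 2, size=100))

    # Select the top-2 features by bidirectional KL divergence between per-class distributions
    selector = Selective(SelectionMethod.Statistical(num_features=2, method="kl_divergence"))
    subset = selector.fit_transform(data, label)

    print("Selected:", list(subset.columns))
    print("Scores:", list(selector.get_absolute_scores()))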

feature/_version.py

Lines changed: 1 addition & 1 deletion

@@ -2,4 +2,4 @@
 # Copyright FMR LLC <[email protected]>
 # SPDX-License-Identifier: Apache-2.0

-__version__ = "1.1.2"
+__version__ = "1.2.0"

feature/kl_divergence.py

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
+# -*- coding: utf-8 -*-
+# Copyright FMR LLC <[email protected]>
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import NoReturn, Tuple
+
+import pandas as pd
+import numpy as np
+
+from scipy.special import rel_entr
+from feature.base import _BaseSupervisedSelector
+from feature.utils import Num, check_true
+
+
+class _KL_Divergence(_BaseSupervisedSelector):
+
+    def __init__(self, seed: int, num_features: Num, num_bins: Num = 100):
+        super().__init__(seed)
+
+        self.num_features = num_features  # this could be int or float
+        self.num_bins = num_bins
+
+    def fit(self, X: pd.DataFrame, y: pd.Series) -> NoReturn:
+
+        label_categories = np.unique(y)
+        check_true(len(label_categories) == 2, TypeError("Only binary labels are supported for KL Divergence"))
+        input_dimension = X.shape[1]
+
+        kl_mat = np.zeros((input_dimension, 1))
+        X = X.values
+
+        class_one_idx = np.where(y == label_categories[0])[0]
+        class_two_idx = np.where(y == label_categories[1])[0]
+
+        for i in range(input_dimension):
+
+            # Create two distributions, one for the positive label and one for the negative label
+            f1 = np.histogram(X[class_one_idx, i], bins = self.num_bins)[0]
+            f2 = np.histogram(X[class_two_idx, i], bins = self.num_bins)[0]
+
+            # Normalize the histogram counts so each distribution sums to 1
+            f1 = f1 / np.sum(f1)
+            f2 = f2 / np.sum(f2)
+
+            # KL Divergence is not symmetric, so we calculate divergence in both directions
+            kl = rel_entr(f1, f2)
+            kl_reversed = rel_entr(f2, f1)
+
+            # The relative entropy function returns KL(P || Q) = np.inf when P != 0 and Q == 0.
+            kl[kl == np.inf] = 9999
+            kl_reversed[kl_reversed == np.inf] = 9999
+
+            # The final score is the combination of KL divergence in both directions.
+            # This could possibly be a flag in a future version to determine which direction to apply KL Divergence
+            # in if bidirectional is not desired.
+            kl_mat[i] = np.sum(kl) + np.sum(kl_reversed)
+
+        scores_ = kl_mat.flatten()
+
+        self.scores_ = scores_  # This is used by the statistical.py fit function.
+        self.abs_scores = scores_
+
+    def transform(self, data: pd.DataFrame) -> pd.DataFrame:
+
+        # Select top-k from data based on abs_scores and num_features
+        return self.get_top_k(data, self.abs_scores)
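To make the score above concrete, here is a self-contained sketch (numpy/scipy only, independent of the class in this file) that reproduces the per-feature computation for a single feature: histogram each class, normalize to probabilities, and sum rel_entr in both directions. The sample data is illustrative.

    import numpy as np
    from scipy.special import rel_entr

    rng = np.random.default_rng(0)
    x_class0 = rng.normal(loc=0.0, scale=1.0, size=500)  # feature values where label == 0
    x_class1 = rng.normal(loc=1.0, scale=1.0, size=500)  # feature values where label == 1

    # Histogram each class separately and normalize counts into probabilities
    f1 = np.histogram(x_class0, bins=100)[0].astype(float)
    f2 = np.histogram(x_class1, bins=100)[0].astype(float)
    f1 /= f1.sum()
    f2 /= f2.sum()

    # Bidirectional KL: rel_entr returns inf in bins where the first distribution is
    # positive but the second is zero, which the selector above caps at a large constant
    kl = rel_entr(f1, f2)
    kl_rev = rel_entr(f2, f1)
    kl[np.isinf(kl)] = 9999
    kl_rev[np.isinf(kl_rev)] = 9999

    score = kl.sum() + kl_rev.sum()
    print(score)  # larger score => the two class-conditional distributions differ more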

feature/selector.py

Lines changed: 9 additions & 2 deletions

@@ -36,7 +36,7 @@


 __author__ = "FMR LLC"
-__version__ = "1.0.0"
+__version__ = "1.2.0"
 __copyright__ = "Copyright (C), FMR LLC"



@@ -202,6 +202,13 @@ class Statistical(NamedTuple):
         searching for the optimal binning strategy.
         Note: MIC is dropped from Selective due to inactive MINE library

+        The KL Divergence feature importance should only be used with
+        binary labels. It computes the distribution of a given feature for instances where label == 1 and label == 0.
+        Uses KL divergence between the two distributions as an importance score,
+        where a higher value indicates greater discriminative power of the feature
+        with respect to the binary label. Since KL Divergence is non-symmetric, this method
+        computes the divergence in both directions and sums them up.
+
         Notes on Randomness:
         - Mutual Info is non-deterministic, depends on the seed value.
         - The other methods are deterministic

@@ -227,7 +234,7 @@ def _validate(self):
            if isinstance(self.num_features, float):
                check_true(self.num_features <= 1, ValueError("Num features ratio must be between [0..1]."))
            # "maximal_info" dropped
-            check_true(self.method in ["anova", "chi_square", "mutual_info", "variance_inflation"],
+            check_true(self.method in ["anova", "chi_square", "kl_divergence", "mutual_info", "variance_inflation"],
                       ValueError("Statistical method can only be anova, chi_square, or mutual_info."))

    class TreeBased(NamedTuple):
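Because the docstring above stresses the binary-label restriction, here is a short hedged sketch of what a caller would see when it is violated (assuming `check_true` raises the supplied exception when its condition is false, as in feature/utils, and that the multi-class data below reaches the guard in feature/kl_divergence.py):

    import numpy as np
    import pandas as pd

    from feature.selector import Selective, SelectionMethod

    selector = Selective(SelectionMethod.Statistical(num_features=2, method="kl_divergence"))

    data = pd.DataFrame(np.random.randn(60, 3), columns=["x1", "x2", "x3"])
    labels = pd.Series([0, 1, 2] * 20)  # three classes -> not supported

    try:
        selector.fit(data, labels)
    except TypeError as error:
        print(error)  # Only binary labels are supported for KL Divergence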

feature/statistical.py

Lines changed: 16 additions & 7 deletions

@@ -10,9 +10,11 @@
 import pandas as pd
 from sklearn.feature_selection import chi2, f_classif, f_regression, mutual_info_classif, mutual_info_regression
 from statsmodels.stats.outliers_influence import variance_inflation_factor
+from scipy.special import rel_entr

 from feature.base import _BaseSupervisedSelector, _BaseDispatcher
 from feature.utils import get_selector, Num, get_task_string
+from feature.kl_divergence import _KL_Divergence


 class _Statistical(_BaseSupervisedSelector, _BaseDispatcher):

@@ -33,15 +35,17 @@ def __init__(self, seed: int, num_features: Num, method: str):
        self.imp = None

        # Implementor factory
-        self.factory = {"regression_anova": f_regression,
-                        "regression_chi_square": None,
-                        "regression_mutual_info": partial(mutual_info_regression, random_state=self.seed),
-                        # "regression_maximal_info": MINE(), # dropped
-                        "classification_anova": f_classif,
+        self.factory = {"classification_anova": f_classif,
                        "classification_chi_square": chi2,
                        "classification_mutual_info": partial(mutual_info_classif, random_state=self.seed),
                        # "classification_maximal_info": MINE(), # dropped
-                        "unsupervised_variance_inflation": variance_inflation_factor}
+                        "kl_divergence" : _KL_Divergence(num_features = self.num_features, seed = self.seed),
+                        "regression_anova": f_regression,
+                        "regression_chi_square": None,
+                        "regression_mutual_info": partial(mutual_info_regression, random_state=self.seed),
+                        # "regression_maximal_info": MINE(), # dropped
+                        "unsupervised_variance_inflation": variance_inflation_factor,
+                        }

    def get_model_args(self, selection_method) -> Tuple:


@@ -54,14 +58,18 @@ def dispatch_model(self, labels: pd.Series, *args):
        method = args[0]

        # Get statistical scoring function
-        if method == "variance_inflation":
+        if method == "kl_divergence":
+            score_func = self.factory.get(method)
+        elif method == "variance_inflation":
            score_func = self.factory.get("unsupervised_" + method)
        else:
            score_func = self.factory.get(get_task_string(labels) + method)

        # Check scoring compatibility with task
        if score_func is None:
            raise TypeError(method + " cannot be used for task: " + get_task_string(labels))
+        elif method == "kl_divergence":
+            self.imp = score_func
        elif method == "variance_inflation":  # or isinstance(score_func, MINE) (dropped)
            self.imp = score_func
        else:

@@ -82,6 +90,7 @@ def fit(self, data: pd.DataFrame, labels: pd.Series) -> NoReturn:
        if self.method == "variance_inflation":
            # VIF is unsupervised, regression between data and each feature
            self.abs_scores = np.array([variance_inflation_factor(data.values, i) for i in range(data.shape[1])])
+
        else:
            # sklearn selector model
            self.imp.fit(X=data, y=labels)
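The dispatch above resolves most methods through a key built as get_task_string(labels) + method (e.g. "classification_anova"), prefixes VIF with "unsupervised_", and looks up "kl_divergence" by its bare name because its implementor is a selector object rather than a task-specific sklearn scoring function. A standalone toy sketch of that key-resolution logic; the values are placeholders, not the library's actual callables:

    # Toy illustration of the factory lookup above (stand-in values only).
    def resolve(factory, method, task):
        if method == "kl_divergence":
            return factory.get(method)               # bare key, task-independent
        if method == "variance_inflation":
            return factory.get("unsupervised_" + method)
        return factory.get(task + method)            # e.g. "classification_" + "anova"

    toy_factory = {
        "classification_anova": "f_classif",
        "classification_chi_square": "chi2",
        "regression_chi_square": None,
        "kl_divergence": "_KL_Divergence(...)",
        "unsupervised_variance_inflation": "variance_inflation_factor",
    }

    print(resolve(toy_factory, "kl_divergence", task="classification_"))  # _KL_Divergence(...)
    print(resolve(toy_factory, "anova", task="classification_"))          # f_classif
    print(resolve(toy_factory, "chi_square", task="regression_"))         # None -> dispatch_model raises TypeError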

setup.py

Lines changed: 2 additions & 2 deletions

@@ -26,10 +26,10 @@
     packages=setuptools.find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
     classifiers=[
         "License :: OSI Approved :: Apache Software License",
-        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
         "Operating System :: OS Independent",
     ],
     project_urls={"Source": "https://github.com/fidelity/selective"},
     install_requires=required,
-    python_requires=">=3.7"
+    python_requires=">=3.8"
 )

tests/run_all.py

Lines changed: 0 additions & 1 deletion

@@ -5,7 +5,6 @@

 import unittest

-
 # Test Directory
 start_dir = '.'


tests/test_benchmark.py

Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@ class TestBenchmark(BaseTest):
        "univ_anova": SelectionMethod.Statistical(num_features, method="anova"),
        "univ_chi_square": SelectionMethod.Statistical(num_features, method="chi_square"),
        "univ_mutual_info": SelectionMethod.Statistical(num_features, method="mutual_info"),
+        "kl_divergence": SelectionMethod.Statistical(num_features, method="kl_divergence"),
        "linear": SelectionMethod.Linear(num_features, regularization="none"),
        "lasso": SelectionMethod.Linear(num_features, regularization="lasso", alpha=alpha),
        "ridge": SelectionMethod.Linear(num_features, regularization="ridge", alpha=alpha),
