Skip to content

Commit 727041a

Browse files
Fix mutual_information bug with string index (#1199)
* fix mi bug with string index * update release notes
1 parent 51fd68d commit 727041a

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

docs/source/release_notes.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@ Future Release
77
==============
88
* Enhancements
99
* Fixes
10+
* Fix bug that causes ``mutual_information`` to fail with certain index types (:pr:`1199`)
1011
* Changes
1112
* Update pip to 21.3.1 for test requirements (:pr:`1196`)
1213
* Documentation Changes
1314
* Update install page with updated minimum optional dependencies (:pr:`1193`)
1415
* Testing Changes
1516

1617
Thanks to the following people for contributing to this release:
17-
:user:`gsheni`
18+
:user:`gsheni`, :user:`thehomebrewnerd`
1819

1920
v0.9.0 Nov 11, 2021
2021
===================

woodwork/statistics_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,8 +280,9 @@ def _get_mutual_information_dict(
280280
if type(col.logical_type) in valid_types
281281
]
282282

283-
if not include_index and dataframe.ww.index is not None:
284-
valid_columns.remove(dataframe.ww.index)
283+
index = dataframe.ww.index
284+
if not include_index and index is not None and index in valid_columns:
285+
valid_columns.remove(index)
285286

286287
data = dataframe.loc[:, valid_columns]
287288
if _is_dask_dataframe(data):

woodwork/tests/accessor/test_statistics.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,23 @@ def test_get_valid_mi_columns_with_index(sample_df):
295295
assert "id" in mi
296296

297297

298+
def test_mutual_info_with_string_index():
299+
df = pd.DataFrame(
300+
{
301+
"id": ["a", "b", "c"],
302+
"col1": [1, 2, 3],
303+
"col2": [10, 20, 30],
304+
}
305+
)
306+
df.ww.init(index="id", logical_types={"id": "unknown"})
307+
mi = df.ww.mutual_information()
308+
309+
cols_used = set(np.unique(mi[["column_1", "column_2"]].values))
310+
assert "id" not in cols_used
311+
assert "col1" in cols_used
312+
assert "col2" in cols_used
313+
314+
298315
def test_get_describe_dict(describe_df):
299316
describe_df.ww.init(index="index_col")
300317

0 commit comments

Comments
 (0)