Skip to content

Commit 30a8b8c

Browse files
committed
work around change in sklearn api
1 parent 4338546 commit 30a8b8c

File tree

6 files changed

+127
-171
lines changed

6 files changed

+127
-171
lines changed

coverage.txt

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,6 @@ tests/test_threshold_stats.py . [100%]
2727
/Users/johnmount/opt/anaconda3/envs/research_env/lib/python3.9/site-packages/setuptools/_distutils/version.py:351: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
2828
other = LooseVersion(other)
2929

30-
tests/test_onehot.py::test_onehot
31-
/Users/johnmount/opt/anaconda3/envs/research_env/lib/python3.9/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function get_feature_names is deprecated; get_feature_names is deprecated in 1.0 and will be removed in 1.2. Please use get_feature_names_out instead.
32-
warnings.warn(msg, category=FutureWarning)
33-
3430
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
3531

3632
---------- coverage: platform darwin, python 3.9.12-final-0 ----------
@@ -40,8 +36,8 @@ wvpy/__init__.py 3 0 100%
4036
wvpy/jtools.py 136 34 75%
4137
wvpy/pysheet.py 33 33 0%
4238
wvpy/render_workbook.py 29 29 0%
43-
wvpy/util.py 322 7 98%
39+
wvpy/util.py 316 7 98%
4440
---------------------------------------------
45-
TOTAL 523 103 80%
41+
TOTAL 517 103 80%
4642

47-
======================= 18 passed, 3 warnings in 14.78s ========================
43+
======================= 18 passed, 2 warnings in 15.55s ========================

pkg/build/lib/wvpy/util.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -899,17 +899,7 @@ def fit_onehot_enc(
899899
categories="auto", drop=None, sparse=False, handle_unknown="ignore" # default
900900
)
901901
enc.fit(d[categorical_var_names])
902-
produced_column_names = list(enc.get_feature_names())
903-
# map back to original column names
904-
905-
def replace_col_name(v):
906-
"""Replace x[0-9]+_level with var_level"""
907-
v_prefix = re.sub(r"_.*$", "", v)
908-
v_suffix = re.sub(r"^x[0-9]+_", "", v)
909-
v_index = int(re.sub(r"^x", "", v_prefix))
910-
return f"{categorical_var_names[v_index]}_{v_suffix}"
911-
912-
produced_column_names = [replace_col_name(v) for v in produced_column_names]
902+
produced_column_names = list(enc.get_feature_names_out())
913903
# return the structure
914904
encoder_bundle = {
915905
"categorical_var_names": categorical_var_names,
-131 Bytes
Binary file not shown.

pkg/dist/wvpy-0.3.0.tar.gz

-108 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)