Skip to content

Commit ed1442e

Browse files
authored
Make sklearn dependency optional (keras-team#20657)
1 parent bce0f5b commit ed1442e

File tree

5 files changed

+104
-66
lines changed

5 files changed

+104
-66
lines changed

keras/src/wrappers/fixes.py

+8-44
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,7 @@
1-
import sklearn
2-
from packaging.version import parse as parse_version
3-
from sklearn import get_config
4-
5-
sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)
6-
7-
if sklearn_version < parse_version("1.6"):
8-
9-
def patched_more_tags(estimator, expected_failed_checks):
10-
import copy
11-
12-
from sklearn.utils._tags import _safe_tags
13-
14-
original_tags = copy.deepcopy(_safe_tags(estimator))
15-
16-
def patched_more_tags(self):
17-
original_tags.update({"_xfail_checks": expected_failed_checks})
18-
return original_tags
19-
20-
estimator.__class__._more_tags = patched_more_tags
21-
return estimator
22-
23-
def parametrize_with_checks(
24-
estimators,
25-
*,
26-
legacy: bool = True,
27-
expected_failed_checks=None,
28-
):
29-
# legacy is not supported and ignored
30-
from sklearn.utils.estimator_checks import parametrize_with_checks # noqa: F401, I001
31-
32-
estimators = [
33-
patched_more_tags(estimator, expected_failed_checks(estimator))
34-
for estimator in estimators
35-
]
36-
37-
return parametrize_with_checks(estimators)
38-
else:
39-
from sklearn.utils.estimator_checks import parametrize_with_checks # noqa: F401, I001
1+
try:
2+
import sklearn
3+
except ImportError:
4+
sklearn = None
405

416

427
def _validate_data(estimator, *args, **kwargs):
@@ -59,9 +24,6 @@ def _validate_data(estimator, *args, **kwargs):
5924

6025

6126
def type_of_target(y, input_name="", *, raise_unknown=False):
62-
# fix for raise_unknown which is introduced in scikit-learn 1.6
63-
from sklearn.utils.multiclass import type_of_target
64-
6527
def _raise_or_return(target_type):
6628
"""Depending on the value of raise_unknown, either raise an error or
6729
return 'unknown'.
@@ -72,7 +34,9 @@ def _raise_or_return(target_type):
7234
else:
7335
return target_type
7436

75-
target_type = type_of_target(y, input_name=input_name)
37+
target_type = sklearn.utils.multiclass.type_of_target(
38+
y, input_name=input_name
39+
)
7640
return _raise_or_return(target_type)
7741

7842

@@ -86,7 +50,7 @@ def _routing_enabled():
8650
8751
TODO: remove when the config key is no longer available in scikit-learn
8852
"""
89-
return get_config().get("enable_metadata_routing", False)
53+
return sklearn.get_config().get("enable_metadata_routing", False)
9054

9155

9256
def _raise_for_params(params, owner, method):

keras/src/wrappers/sklearn_test.py

+43-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
from contextlib import contextmanager
44

55
import pytest
6+
import sklearn
7+
from packaging.version import parse as parse_version
8+
from sklearn.utils.estimator_checks import parametrize_with_checks
69

710
import keras
811
from keras.src.backend import floatx
@@ -13,7 +16,45 @@
1316
from keras.src.wrappers import SKLearnClassifier
1417
from keras.src.wrappers import SKLearnRegressor
1518
from keras.src.wrappers import SKLearnTransformer
16-
from keras.src.wrappers.fixes import parametrize_with_checks
19+
20+
21+
def wrapped_parametrize_with_checks(
    estimators,
    *,
    legacy: bool = True,
    expected_failed_checks=None,
):
    """Version-aware wrapper around sklearn's `parametrize_with_checks`.

    When the installed scikit-learn is at least 1.6, all arguments are
    forwarded unchanged. Otherwise the expected failures are attached to each
    estimator through a patched `_more_tags` that reports them under the
    `_xfail_checks` tag, and only `estimators` is forwarded.

    Args:
        estimators: Iterable of estimator instances to generate checks for.
        legacy: Forwarded on scikit-learn >= 1.6; not supported and ignored
            on older versions.
        expected_failed_checks: Optional callable taking an estimator and
            returning the value to expose as its `_xfail_checks` tag. When
            `None`, estimators are passed through without any patching.

    Returns:
        The result of `parametrize_with_checks` (applied as a test decorator
        in this file).
    """
    base_version = parse_version(sklearn.__version__).base_version
    if parse_version(base_version) >= parse_version("1.6"):
        return parametrize_with_checks(
            estimators,
            legacy=legacy,
            expected_failed_checks=expected_failed_checks,
        )

    # legacy is not supported and ignored
    if expected_failed_checks is None:
        # Nothing to mark as an expected failure; calling `None` below would
        # raise `TypeError`, so hand the estimators over untouched.
        return parametrize_with_checks(estimators)

    import copy

    # Import the private helper explicitly: attribute access through
    # `sklearn.utils._tags` only works if that submodule happens to have
    # been imported already.
    from sklearn.utils._tags import _safe_tags

    def _with_xfail_tag(estimator, xfail_checks):
        """Patch the estimator's class so its tags carry `xfail_checks`."""
        tags = copy.deepcopy(_safe_tags(estimator))

        def _more_tags(self):
            tags.update({"_xfail_checks": xfail_checks})
            return tags

        # NOTE: assigned on the class, so every instance of it is affected.
        estimator.__class__._more_tags = _more_tags
        return estimator

    patched_estimators = [
        _with_xfail_tag(estimator, expected_failed_checks(estimator))
        for estimator in estimators
    ]
    return parametrize_with_checks(patched_estimators)
1758

1859

1960
def dynamic_model(X, y, loss, layers=[10]):
@@ -80,7 +121,7 @@ def use_floatx(x: str):
80121
}
81122

82123

83-
@parametrize_with_checks(
124+
@wrapped_parametrize_with_checks(
84125
estimators=[
85126
SKLearnClassifier(
86127
model=dynamic_model,

keras/src/wrappers/sklearn_wrapper.py

+31-13
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,6 @@
11
import copy
22

33
import numpy as np
4-
from sklearn.base import BaseEstimator
5-
from sklearn.base import ClassifierMixin
6-
from sklearn.base import RegressorMixin
7-
from sklearn.base import TransformerMixin
8-
from sklearn.base import check_is_fitted
9-
from sklearn.pipeline import make_pipeline
10-
from sklearn.preprocessing import OneHotEncoder
11-
from sklearn.utils.metadata_routing import MetadataRequest
124

135
from keras.src.api_export import keras_export
146
from keras.src.models.cloning import clone_model
@@ -18,6 +10,28 @@
1810
from keras.src.wrappers.fixes import type_of_target
1911
from keras.src.wrappers.utils import TargetReshaper
2012
from keras.src.wrappers.utils import _check_model
13+
from keras.src.wrappers.utils import assert_sklearn_installed
14+
15+
try:
16+
import sklearn
17+
from sklearn.base import BaseEstimator
18+
from sklearn.base import ClassifierMixin
19+
from sklearn.base import RegressorMixin
20+
from sklearn.base import TransformerMixin
21+
except ImportError:
22+
sklearn = None
23+
24+
class BaseEstimator:
25+
pass
26+
27+
class ClassifierMixin:
28+
pass
29+
30+
class RegressorMixin:
31+
pass
32+
33+
class TransformerMixin:
34+
pass
2135

2236

2337
class SKLBase(BaseEstimator):
@@ -64,6 +78,7 @@ def __init__(
6478
model_kwargs=None,
6579
fit_kwargs=None,
6680
):
81+
assert_sklearn_installed(self.__class__.__name__)
6782
self.model = model
6883
self.warm_start = warm_start
6984
self.model_kwargs = model_kwargs
@@ -119,7 +134,9 @@ def set_fit_request(self, **kwargs):
119134
"sklearn.set_config(enable_metadata_routing=True)."
120135
)
121136

122-
self._metadata_request = MetadataRequest(owner=self.__class__.__name__)
137+
self._metadata_request = sklearn.utils.metadata_routing.MetadataRequest(
138+
owner=self.__class__.__name__
139+
)
123140
for param, alias in kwargs.items():
124141
self._metadata_request.score.add_request(param=param, alias=alias)
125142
return self
@@ -155,7 +172,7 @@ def fit(self, X, y, **kwargs):
155172

156173
def predict(self, X):
157174
"""Predict using the model."""
158-
check_is_fitted(self)
175+
sklearn.base.check_is_fitted(self)
159176
X = _validate_data(self, X, reset=False)
160177
raw_output = self.model_.predict(X)
161178
return self._reverse_process_target(raw_output)
@@ -267,8 +284,9 @@ def _process_target(self, y, reset=False):
267284
f" Target type: {target_type}"
268285
)
269286
if reset:
270-
self._target_encoder = make_pipeline(
271-
TargetReshaper(), OneHotEncoder(sparse_output=False)
287+
self._target_encoder = sklearn.pipeline.make_pipeline(
288+
TargetReshaper(),
289+
sklearn.preprocessing.OneHotEncoder(sparse_output=False),
272290
).fit(y)
273291
self.classes_ = np.unique(y)
274292
if len(self.classes_) == 1:
@@ -454,7 +472,7 @@ def transform(self, X):
454472
X_transformed: array-like, shape=(n_samples, n_features)
455473
The transformed data.
456474
"""
457-
check_is_fitted(self)
475+
sklearn.base.check_is_fitted(self)
458476
X = _validate_data(self, X, reset=False)
459477
return self.model_.predict(X)
460478

keras/src/wrappers/utils.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,23 @@
1-
from sklearn.base import BaseEstimator
2-
from sklearn.base import TransformerMixin
3-
from sklearn.base import check_is_fitted
4-
from sklearn.utils._array_api import get_namespace
1+
try:
2+
import sklearn
3+
from sklearn.base import BaseEstimator
4+
from sklearn.base import TransformerMixin
5+
except ImportError:
6+
sklearn = None
7+
8+
class BaseEstimator:
9+
pass
10+
11+
class TransformerMixin:
12+
pass
13+
14+
15+
def assert_sklearn_installed(symbol_name):
    """Raise `ImportError` when scikit-learn is not available.

    Args:
        symbol_name: Name of the symbol that needs scikit-learn; it is
            interpolated at the start of the error message.

    Raises:
        ImportError: If the module-level `sklearn` name is `None`.
    """
    if sklearn is not None:
        return
    message = (
        f"{symbol_name} requires `scikit-learn` to be installed. "
        "Run `pip install scikit-learn` to install it."
    )
    raise ImportError(message)
521

622

723
def _check_model(model):
@@ -64,8 +80,8 @@ def inverse_transform(self, y):
6480
is passed, it will be squeezed back to 1D. Otherwise, it
6581
will be left untouched.
6682
"""
67-
check_is_fitted(self)
68-
xp, _ = get_namespace(y)
83+
sklearn.base.check_is_fitted(self)
84+
xp, _ = sklearn.utils._array_api.get_namespace(y)
6985
if self.ndim_ == 1 and y.ndim == 2:
7086
return xp.squeeze(y, axis=1)
7187
return y

pyproject.toml

-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ dependencies = [
3434
"optree",
3535
"ml-dtypes",
3636
"packaging",
37-
"scikit-learn",
3837
]
3938
# Run also: pip install -r requirements.txt
4039

0 commit comments

Comments
 (0)