|
1 | 1 | import copy
|
2 | 2 |
|
3 | 3 | import numpy as np
|
4 |
| -from sklearn.base import BaseEstimator |
5 |
| -from sklearn.base import ClassifierMixin |
6 |
| -from sklearn.base import RegressorMixin |
7 |
| -from sklearn.base import TransformerMixin |
8 |
| -from sklearn.base import check_is_fitted |
9 |
| -from sklearn.pipeline import make_pipeline |
10 |
| -from sklearn.preprocessing import OneHotEncoder |
11 |
| -from sklearn.utils.metadata_routing import MetadataRequest |
12 | 4 |
|
13 | 5 | from keras.src.api_export import keras_export
|
14 | 6 | from keras.src.models.cloning import clone_model
|
|
18 | 10 | from keras.src.wrappers.fixes import type_of_target
|
19 | 11 | from keras.src.wrappers.utils import TargetReshaper
|
20 | 12 | from keras.src.wrappers.utils import _check_model
|
| 13 | +from keras.src.wrappers.utils import assert_sklearn_installed |
| 14 | + |
| 15 | +try: |
| 16 | + import sklearn |
| 17 | + from sklearn.base import BaseEstimator |
| 18 | + from sklearn.base import ClassifierMixin |
| 19 | + from sklearn.base import RegressorMixin |
| 20 | + from sklearn.base import TransformerMixin |
| 21 | +except ImportError: |
| 22 | + sklearn = None |
| 23 | + |
| 24 | + class BaseEstimator: |
| 25 | + pass |
| 26 | + |
| 27 | + class ClassifierMixin: |
| 28 | + pass |
| 29 | + |
| 30 | + class RegressorMixin: |
| 31 | + pass |
| 32 | + |
| 33 | + class TransformerMixin: |
| 34 | + pass |
21 | 35 |
|
22 | 36 |
|
23 | 37 | class SKLBase(BaseEstimator):
|
@@ -64,6 +78,7 @@ def __init__(
|
64 | 78 | model_kwargs=None,
|
65 | 79 | fit_kwargs=None,
|
66 | 80 | ):
|
| 81 | + assert_sklearn_installed(self.__class__.__name__) |
67 | 82 | self.model = model
|
68 | 83 | self.warm_start = warm_start
|
69 | 84 | self.model_kwargs = model_kwargs
|
@@ -119,7 +134,9 @@ def set_fit_request(self, **kwargs):
|
119 | 134 | "sklearn.set_config(enable_metadata_routing=True)."
|
120 | 135 | )
|
121 | 136 |
|
122 |
| - self._metadata_request = MetadataRequest(owner=self.__class__.__name__) |
| 137 | + self._metadata_request = sklearn.utils.metadata_routing.MetadataRequest( |
| 138 | + owner=self.__class__.__name__ |
| 139 | + ) |
123 | 140 | for param, alias in kwargs.items():
|
124 | 141 | self._metadata_request.score.add_request(param=param, alias=alias)
|
125 | 142 | return self
|
@@ -155,7 +172,7 @@ def fit(self, X, y, **kwargs):
|
155 | 172 |
|
156 | 173 | def predict(self, X):
|
157 | 174 | """Predict using the model."""
|
158 |
| - check_is_fitted(self) |
| 175 | + sklearn.base.check_is_fitted(self) |
159 | 176 | X = _validate_data(self, X, reset=False)
|
160 | 177 | raw_output = self.model_.predict(X)
|
161 | 178 | return self._reverse_process_target(raw_output)
|
@@ -267,8 +284,9 @@ def _process_target(self, y, reset=False):
|
267 | 284 | f" Target type: {target_type}"
|
268 | 285 | )
|
269 | 286 | if reset:
|
270 |
| - self._target_encoder = make_pipeline( |
271 |
| - TargetReshaper(), OneHotEncoder(sparse_output=False) |
| 287 | + self._target_encoder = sklearn.pipeline.make_pipeline( |
| 288 | + TargetReshaper(), |
| 289 | + sklearn.preprocessing.OneHotEncoder(sparse_output=False), |
272 | 290 | ).fit(y)
|
273 | 291 | self.classes_ = np.unique(y)
|
274 | 292 | if len(self.classes_) == 1:
|
@@ -454,7 +472,7 @@ def transform(self, X):
|
454 | 472 | X_transformed: array-like, shape=(n_samples, n_features)
|
455 | 473 | The transformed data.
|
456 | 474 | """
|
457 |
| - check_is_fitted(self) |
| 475 | + sklearn.base.check_is_fitted(self) |
458 | 476 | X = _validate_data(self, X, reset=False)
|
459 | 477 | return self.model_.predict(X)
|
460 | 478 |
|
|
0 commit comments