Commit ad2793e

Better native support for sklearn
- Remove the SKlearnCompat wrapper. The make_sklearn_compat and .to_lale() methods are deprecated and will be removed soon; they are now identity functions.
- get_params is now supported for Operators (at least for deep=False).
- sklearn clone is now supported for Operators.
- set_params is now supported for TrainableOperators and BasePipelines.
- with_params, an immutable variant of set_params, is now supported for all Operators and is the recommended alternative.
- Implementation of Operator:
  - _impl and _impl_class_ have changed a bit. Make sure that you use the accessor functions to access them and instantiate them as needed.
  - _hyperparams may now store values that were not explicitly set by the user, with _frozen_hyperparams marking which keys were explicitly set. Use the various hyperparams() accessor methods to get the desired set of values.
- Since one of our tests (TestAutoAIOutputConsumption) currently uses an old pickle, code was temporarily added to migrate this pickle to the new format.
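
A minimal sketch of the new surface area (hedged: LogisticRegression and the hyperparameter values are illustrative, and with_params is assumed to mirror set_params's keyword interface):

    from sklearn.base import clone
    from lale.lib.sklearn import LogisticRegression

    lr = LogisticRegression(C=0.1)

    # get_params now works on Operators (deep=False is the supported case)
    params = lr.get_params(deep=False)

    # sklearn's clone protocol now works directly on Operators
    lr_copy = clone(lr)

    # with_params is the immutable, recommended alternative to set_params:
    # it returns a new operator and leaves lr unchanged
    lr2 = lr.with_params(C=1.0)
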
1 parent 77230de commit ad2793e

24 files changed (+1328 −956 lines)

examples/docs_guide_for_sklearn_users.ipynb

+1 −1

@@ -1262,7 +1262,7 @@
     ],
     "source": [
      "from lale.pretty_print import ipython_display\n",
-     "ipython_display(Tree.get_defaults())"
+     "ipython_display(dict(Tree.get_defaults()))"
     ]
    },
    {
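
The dict(...) wrapper suggests that get_defaults() no longer returns a plain dict, consistent with the hyperparams storage changes in this commit. A hedged sketch of the fixed notebook cell (the Tree binding shown here is a guess at what the guide defines earlier):

    # assumption: the guide binds Tree to a tree-based operator, e.g.:
    from lale.lib.sklearn import DecisionTreeClassifier as Tree
    from lale.pretty_print import ipython_display

    ipython_display(dict(Tree.get_defaults()))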

lale/grammar.py

+31 −1

@@ -8,11 +8,11 @@
     Operator,
     OperatorChoice,
     PlannedOperator,
+    clone_op,
     make_choice,
     make_pipeline,
     make_pipeline_graph,
 )
-from lale.sklearn_compat import clone_op


 class NonTerminal(Operator):
@@ -24,6 +24,25 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]:
         out["name"] = self._name
         return out

+    def _with_params(self, try_mutate: bool, **impl_params) -> Operator:
+        """
+        This method updates the parameters of the operator. NonTerminals do not support
+        in-place mutation
+        """
+        known_keys = set(["name"])
+        if impl_params:
+            new_keys = set(impl_params.keys())
+            if not new_keys.issubset(known_keys):
+                unknowns = {k: v for k, v in impl_params.items() if k not in known_keys}
+                raise ValueError(
+                    f"NonTerminal._with_params called with unknown parameters: {unknowns}"
+                )
+            else:
+                assert "name" in impl_params
+                return NonTerminal(impl_params["name"])
+        else:
+            return self
+
     def __init__(self, name):
         self._name = name

@@ -56,8 +75,19 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]:
         out = {}
         out["variables"] = self.variables
         # todo: support deep=True
+        # just like a higher order operator does
         return out

+    def _with_params(self, try_mutate: bool, **impl_params) -> Operator:
+        """
+        This method updates the parameters of the operator.
+        If try_mutate is set, it will attempt to update the operator in place
+        this may not always be possible
+        """
+        # TODO implement support
+        # from this point of view, Grammar is just a higher order operator
+        raise NotImplementedError("setting Grammar parameters is not yet supported")
+
     def __init__(self, variables: Dict[str, Operator] = {}):
         self._variables = variables

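For orientation, a hedged sketch of what the new NonTerminal._with_params accepts, based only on the code above:

    from lale.grammar import NonTerminal

    nt = NonTerminal("estimator")
    nt2 = nt._with_params(False, name="classifier")  # fresh NonTerminal, nt unchanged
    nt3 = nt._with_params(False)                     # no parameters: returns nt itself
    # any key other than "name" raises:
    # nt._with_params(False, depth=3)  -> ValueError
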
lale/helpers.py

+153 −4

@@ -21,7 +21,18 @@
 import sys
 import time
 import traceback
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+)

 import h5py
 import jsonschema
@@ -434,8 +445,9 @@ def create_instance_from_hyperopt_search_space(lale_object, hyperparams):

     if isinstance(lale_object, PlannedIndividualOp):
         new_hyperparams: Dict[str, Any] = dict_without(hyperparams, "name")
-        if lale_object._hyperparams is not None:
-            obj_hyperparams = dict(lale_object._hyperparams)
+        hps = lale_object.hyperparams()
+        if hps is not None:
+            obj_hyperparams = dict(hps)
         else:
             obj_hyperparams = {}

@@ -548,7 +560,7 @@ def get_equivalent_lale_op(sklearn_obj, fitted):
         lale_op = class_
     else:
         lale_op = lale.operators.TrainedIndividualOp(
-            class_._name, class_._impl, class_._schemas, None
+            class_._name, class_._impl, class_._schemas, None, _lale_trained=True
         )

     try:
@@ -767,3 +779,140 @@ def add_missing_values(orig_X, missing_rate=0.1, seed=None):
             i_missing_sample += 1
             missing_X.iloc[i_sample, i_feature] = np.nan
     return missing_X
+
+
+# helpers for manipulating (extended) sklearn style paths.
+# documentation of the path format is part of the operators module docstring
+
+
+def partition_sklearn_params(
+    d: Dict[str, Any]
+) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:
+    sub_parts: Dict[str, Dict[str, Any]] = {}
+    main_parts: Dict[str, Any] = {}
+
+    for k, v in d.items():
+        ks = k.split("__", 1)
+        if len(ks) == 1:
+            assert k not in main_parts
+            main_parts[k] = v
+        else:
+            assert len(ks) == 2
+            bucket: Dict[str, Any] = {}
+            group: str = ks[0]
+            param: str = ks[1]
+            if group in sub_parts:
+                bucket = sub_parts[group]
+            else:
+                sub_parts[group] = bucket
+            assert param not in bucket
+            bucket[param] = v
+    return (main_parts, sub_parts)
+
+
+def partition_sklearn_choice_params(d: Dict[str, Any]) -> Tuple[int, Dict[str, Any]]:
+    discriminant_value: int = -1
+    choice_parts: Dict[str, Any] = {}
+
+    for k, v in d.items():
+        if k == discriminant_name:
+            assert discriminant_value == -1
+            discriminant_value = int(v)
+        else:
+            k_rest = unnest_choice(k)
+            choice_parts[k_rest] = v
+    assert discriminant_value != -1
+    return (discriminant_value, choice_parts)
+
+
+DUMMY_SEARCH_SPACE_GRID_PARAM_NAME: str = "$"
+discriminant_name: str = "?"
+choice_prefix: str = "?"
+structure_type_name: str = "#"
+structure_type_list: str = "list"
+structure_type_tuple: str = "tuple"
+structure_type_dict: str = "dict"
+
+
+def get_name_and_index(name: str) -> Tuple[str, int]:
+    """ given a name of the form "name@i", returns (name, i)
+    if given a name of the form "name", returns (name, 0)
+    """
+    splits = name.split("@", 1)
+    if len(splits) == 1:
+        return splits[0], 0
+    else:
+        return splits[0], int(splits[1])
+
+
+def make_degen_indexed_name(name, index):
+    return f"{name}@{index}"
+
+
+def make_indexed_name(name, index):
+    if index == 0:
+        return name
+    else:
+        return f"{name}@{index}"
+
+
+def make_array_index_name(index, is_tuple: bool = False):
+    sep = "##" if is_tuple else "#"
+    return f"{sep}{str(index)}"
+
+
+def is_numeric_structure(structure_type: str):
+
+    if structure_type == "list" or structure_type == "tuple":
+        return True
+    elif structure_type == "dict":
+        return False
+    else:
+        assert False, f"Unknown structure type {structure_type} found"
+
+
+V = TypeVar("V")
+
+
+def nest_HPparam(name: str, key: str):
+    if key == DUMMY_SEARCH_SPACE_GRID_PARAM_NAME:
+        # we can get rid of the dummy now, since we have a name for it
+        return name
+    return name + "__" + key
+
+
+def nest_HPparams(name: str, grid: Mapping[str, V]) -> Dict[str, V]:
+    return {(nest_HPparam(name, k)): v for k, v in grid.items()}
+
+
+def nest_all_HPparams(
+    name: str, grids: Iterable[Mapping[str, V]]
+) -> List[Dict[str, V]]:
+    """ Given the name of an operator in a pipeline, this transforms every key(parameter name) in the grids
+    to use the operator name as a prefix (separated by __). This is the convention in scikit-learn pipelines.
+    """
+    return [nest_HPparams(name, grid) for grid in grids]
+
+
+def nest_choice_HPparam(key: str):
+    return choice_prefix + key
+
+
+def nest_choice_HPparams(grid: Mapping[str, V]) -> Dict[str, V]:
+    return {(nest_choice_HPparam(k)): v for k, v in grid.items()}
+
+
+def nest_choice_all_HPparams(grids: Iterable[Mapping[str, V]]) -> List[Dict[str, V]]:
+    """ this transforms every key(parameter name) in the grids
+    to be nested under a choice, using a ? as a prefix (separated by __). This is the convention in scikit-learn pipelines.
+    """
+    return [nest_choice_HPparams(grid) for grid in grids]
+
+
+def unnest_choice(k: str) -> str:
+    assert k.startswith(choice_prefix)
+    return k[len(choice_prefix) :]
+
+
+def unnest_HPparams(k: str) -> List[str]:
+    return k.split("__")

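These helpers round-trip the extended sklearn path syntax. A small sketch of what they compute, with values chosen purely for illustration:

    from lale.helpers import (
        get_name_and_index,
        nest_all_HPparams,
        partition_sklearn_params,
    )

    main, sub = partition_sklearn_params(
        {"solver": "lbfgs", "lr__C": 0.1, "lr__penalty": "l2"}
    )
    # main == {"solver": "lbfgs"}
    # sub == {"lr": {"C": 0.1, "penalty": "l2"}}

    # nesting adds the operator-name prefix that sklearn pipelines expect
    nest_all_HPparams("op", [{"C": [0.1, 1.0]}])  # -> [{"op__C": [0.1, 1.0]}]

    # "@" suffixes disambiguate repeated operator names; index 0 is implicit
    get_name_and_index("PCA@1")  # -> ("PCA", 1)
    get_name_and_index("PCA")    # -> ("PCA", 0)
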
lale/json_operator.py

+2 −2

@@ -391,12 +391,12 @@ def _op_to_json_rec(
     if isinstance(op, lale.operators.TrainableIndividualOp):
         if hasattr(op._impl, "viz_label"):
             jsn["viz_label"] = op._impl.viz_label()
-        if op.hyperparams() is None:
+        if op.reduced_hyperparams() is None:
             jsn["hyperparams"] = None
         else:
             steps: Dict[str, JSON_TYPE] = {}
             jsn["hyperparams"] = _hps_to_json_rec(
-                op.hyperparams(), cls2label, gensym, steps
+                op.reduced_hyperparams(), cls2label, gensym, steps
             )
             if len(steps) > 0:
                 jsn["steps"] = steps

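Given the _frozen_hyperparams change in this commit, reduced_hyperparams() presumably returns only the hyperparameters the user set explicitly, keeping materialized defaults out of the JSON form. A hedged sketch of the inferred difference (semantics guessed from the commit message, not verified):

    from lale.lib.sklearn import LogisticRegression

    lr = LogisticRegression(C=0.1)
    lr.reduced_hyperparams()  # guess: {"C": 0.1}, just the explicitly set keys
    lr.hyperparams()          # may also include values recorded internally
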
lale/lib/imblearn/base_resampler.py

+1 −3

@@ -12,12 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from lale.sklearn_compat import make_sklearn_compat_opt
-

 class BaseResamplerImpl:
     def __init__(self, operator=None, resampler=None):
-        self.operator = make_sklearn_compat_opt(operator)
+        self.operator = operator
         self.resampler = resampler

     def fit(self, X, y=None):

lale/lib/lale/grid_search_cv.py

+35 −15

@@ -71,8 +71,14 @@ def fit(self, X, y):

         observed_op = op
         obs = self._hyperparams["observer"]
-        if obs is not None:
-            observed_op = Observing(op=op, observer=obs)
+        # We always create an observer.
+        # Otherwise, we can have a problem with PlannedOperators
+        # (that are not trainable):
+        # GridSearchCV checks if a fit method is present before
+        # configuring the operator, and our planned operators
+        # don't have a fit method
+        # Observing always has a fit method, and so solves this problem.
+        observed_op = Observing(op=op, observer=obs)

         hp_grid = self._hyperparams["hp_grid"]
         data_schema = lale.helpers.fold_schema(
@@ -86,14 +92,24 @@ def fit(self, X, y):
                 pgo=self._hyperparams["pgo"],
                 data_schema=data_schema,
             )
+        else:
+            # if hp_grid is specified manually, we need to add a level of nesting
+            # since we are wrapping it in an observer
+            if isinstance(hp_grid, list):
+                hp_grid = lale.helpers.nest_all_HPparams("op", hp_grid)
+            else:
+                assert isinstance(hp_grid, dict)
+                hp_grid = lale.helpers.nest_HPparams("op", hp_grid)
+
         if not hp_grid and isinstance(op, lale.operators.IndividualOp):
             hp_grid = [
-                lale.search.lale_grid_search_cv.get_defaults_as_param_grid(observed_op)
+                lale.search.lale_grid_search_cv.get_defaults_as_param_grid(observed_op)  # type: ignore
             ]
         be: lale.operators.TrainableOperator
         if hp_grid:
             if obs is not None:
-                observed_op._impl.startObserving(
+                impl = observed_op._impl  # type: ignore
+                impl.startObserving(
                     "optimize",
                     hp_grid=hp_grid,
                     op=op,
@@ -103,27 +119,30 @@ def fit(self, X, y):
                 )
             try:
                 self.grid = lale.search.lale_grid_search_cv.get_lale_gridsearchcv_op(
-                    lale.sklearn_compat.make_sklearn_compat(observed_op),
+                    observed_op,
                     hp_grid,
                     cv=self._hyperparams["cv"],
                     verbose=self._hyperparams["verbose"],
                     scoring=self._hyperparams["scoring"],
                     n_jobs=self._hyperparams["n_jobs"],
                 )
                 self.grid.fit(X, y)
-                be = self.grid.best_estimator_.to_lale()
+                be = self.grid.best_estimator_
             except BaseException as e:
                 if obs is not None:
-                    assert isinstance(observed_op._impl, ObservingImpl)
-                    observed_op._impl.failObserving("optimize", e)
+                    impl = observed_op._impl  # type: ignore
+                    assert isinstance(impl, ObservingImpl)
+                    impl.failObserving("optimize", e)
                 raise

-            if obs is not None:
-                impl = getattr(be, "_impl")
-                if impl is not None:
-                    assert isinstance(impl, ObservingImpl)
-                    be = impl.getOp()
-                observed_op._impl.endObserving("optimize", best=be)
+            impl = getattr(be, "_impl", None)
+            if impl is not None:
+                assert isinstance(impl, ObservingImpl)
+                be = impl.getOp()
+            if obs is not None:
+                obs_impl = observed_op._impl  # type: ignore
+
+                obs_impl.endObserving("optimize", best=be)
         else:
             assert isinstance(op, lale.operators.TrainableOperator)
             be = op
@@ -142,7 +161,8 @@ def get_pipeline(self, pipeline_name=None, astype="lale"):
         if result is None or astype == "lale":
             return result
         assert astype == "sklearn", astype
-        return lale.sklearn_compat.make_sklearn_compat(result)
+        # TODO: should this try and return an actual sklearn pipeline?
+        return result


 _hyperparams_schema = {

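Because fit now always wraps op in Observing(op=op, observer=obs), a manually supplied grid must address the wrapped step; the nesting above does exactly that. An illustrative sketch (grid values invented):

    import lale.helpers

    user_grid = {"C": [0.1, 1.0, 10.0]}
    # after wrapping, the estimator seen by sklearn's GridSearchCV is
    # Observing(op=...), so every key gains the "op__" prefix
    lale.helpers.nest_HPparams("op", user_grid)
    # -> {"op__C": [0.1, 1.0, 10.0]}
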
lale/lib/lale/smac.py

+2 −1

@@ -217,7 +217,8 @@ def get_pipeline(self, pipeline_name=None, astype="lale"):
         if result is None or astype == "lale":
             return result
         assert astype == "sklearn", astype
-        return lale.sklearn_compat.make_sklearn_compat(result)
+        # TODO: should this try and return an actual sklearn pipeline?
+        return result


 _hyperparams_schema = {
