Skip to content

Commit 2010dbd

Browse files
authored
Remove mandatory dask_ml dependencies (#208)
* Remove mandatory dask_ml dependencies * Increase coverage
1 parent 0c3a6a1 commit 2010dbd

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

dask_sql/physical/rel/custom/create_experiment.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import dask.dataframe as dd
44
import pandas as pd
5-
from dask_ml.wrappers import ParallelPostFit
65

76
from dask_sql.datacontainer import ColumnContainer, DataContainer
87
from dask_sql.java import org
@@ -163,6 +162,13 @@ def convert(
163162
f"Can not import tuner {experiment_class}. Make sure you spelled it correctly and have installed all packages."
164163
)
165164

165+
try:
166+
from dask_ml.wrappers import ParallelPostFit
167+
except ImportError: # pragma: no cover
168+
raise ValueError(
169+
"dask_ml must be installed to use automl and tune hyperparameters"
170+
)
171+
166172
model = ModelClass()
167173

168174
search = ExperimentClass(model, {**parameters}, **experiment_kwargs)
@@ -186,6 +192,14 @@ def convert(
186192
raise ValueError(
187193
f"Can not import automl model {automl_class}. Make sure you spelled it correctly and have installed all packages."
188194
)
195+
196+
try:
197+
from dask_ml.wrappers import ParallelPostFit
198+
except ImportError: # pragma: no cover
199+
raise ValueError(
200+
"dask_ml must be installed to use automl and tune hyperparameters"
201+
)
202+
189203
automl = AutoMLClass(**automl_kwargs)
190204
# should be avoided if data doesn't fit in memory
191205
automl.fit(X.compute(), y.compute())

dask_sql/physical/rel/custom/create_model.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,18 @@ def convert(
139139

140140
model = ModelClass(**kwargs)
141141
if wrap_fit:
142-
from dask_ml.wrappers import Incremental
142+
try:
143+
from dask_ml.wrappers import Incremental
144+
except ImportError: # pragma: no cover
145+
raise ValueError("Wrapping requires dask-ml to be installed.")
143146

144147
model = Incremental(estimator=model)
145148

146149
if wrap_predict:
147-
from dask_ml.wrappers import ParallelPostFit
150+
try:
151+
from dask_ml.wrappers import ParallelPostFit
152+
except ImportError: # pragma: no cover
153+
raise ValueError("Wrapping requires dask-ml to be installed.")
148154

149155
model = ParallelPostFit(estimator=model)
150156

0 commit comments

Comments
 (0)