Skip to content

Commit d3dc522

Browse files
committed
learner adequacy check refactor
1 parent 4801feb commit d3dc522

File tree

6 files changed

+22
-23
lines changed

6 files changed

+22
-23
lines changed

Orange/base.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Iterable
44
import re
55
import warnings
6-
from typing import Callable, Dict
6+
from typing import Callable, Dict, Tuple
77

88
import numpy as np
99
import scipy
@@ -86,7 +86,6 @@ class Learner(ReprableWithPreprocessors):
8686
#: A sequence of data preprocessors to apply on data prior to
8787
#: fitting the model
8888
preprocessors = ()
89-
learner_adequacy_err_msg = ''
9089

9190
def __init__(self, preprocessors=None):
9291
self.use_default_preprocessors = False
@@ -106,8 +105,9 @@ def fit_storage(self, data):
106105
return self.fit(X, Y, W)
107106

108107
def __call__(self, data, progress_callback=None):
109-
if not self.check_learner_adequacy(data.domain):
110-
raise ValueError(self.learner_adequacy_err_msg)
108+
learner_is_adequate, err_msg = self.check_learner_adequacy(data.domain)
109+
if not learner_is_adequate:
110+
raise ValueError(err_msg)
111111

112112
origdomain = data.domain
113113

@@ -173,8 +173,8 @@ def active_preprocessors(self):
173173
self.preprocessors is not type(self).preprocessors):
174174
yield from type(self).preprocessors
175175

176-
def check_learner_adequacy(self, _):
177-
return True
176+
def check_learner_adequacy(self, _) -> Tuple[bool, str]:
177+
return True, ""
178178

179179
@property
180180
def name(self):

Orange/classification/base_classification.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,12 @@
77
class LearnerClassification(Learner):
88

99
def check_learner_adequacy(self, domain):
10-
is_adequate = True
10+
err_msg = ""
1111
if len(domain.class_vars) > 1:
12-
is_adequate = False
13-
self.learner_adequacy_err_msg = "Too many target variables."
12+
err_msg = "Too many target variables."
1413
elif not domain.has_discrete_class:
15-
is_adequate = False
16-
self.learner_adequacy_err_msg = "Categorical class variable expected."
17-
return is_adequate
14+
err_msg = "Categorical class variable expected."
15+
return not err_msg, err_msg
1816

1917

2018
class ModelClassification(Model):

Orange/preprocess/impute.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,8 @@ def __call__(self, data, variable):
224224
variable = data.domain[variable]
225225
domain = domain_with_class_var(data.domain, variable)
226226

227-
if self.learner.check_learner_adequacy(domain):
227+
learner_is_adequate, err_msg = self.learner.check_learner_adequacy(data.domain)
228+
if learner_is_adequate:
228229
data = data.transform(domain)
229230
model = self.learner(data)
230231
assert model.domain.class_var == variable
@@ -239,7 +240,8 @@ def copy(self):
239240

240241
def supports_variable(self, variable):
241242
domain = Orange.data.Domain([], class_vars=variable)
242-
return self.learner.check_learner_adequacy(domain)
243+
learner_is_adequate, _ = self.learner.check_learner_adequacy(domain)
244+
return learner_is_adequate
243245

244246

245247
def domain_with_class_var(domain, class_var):

Orange/regression/base_regression.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,12 @@
77
class LearnerRegression(Learner):
88

99
def check_learner_adequacy(self, domain):
10-
is_adequate = True
10+
err_msg = ""
1111
if len(domain.class_vars) > 1:
12-
is_adequate = False
13-
self.learner_adequacy_err_msg = "Too many target variables."
12+
err_msg = "Too many target variables."
1413
elif not domain.has_continuous_class:
15-
is_adequate = False
16-
self.learner_adequacy_err_msg = "Numeric class variable expected."
17-
return is_adequate
14+
err_msg = "Numeric class variable expected."
15+
return not err_msg, err_msg
1816

1917

2018
class ModelRegression(Model):

Orange/tests/dummy_learners.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class DummyMulticlassLearner(SklLearner):
3434
supports_multiclass = True
3535

3636
def check_learner_adequacy(self, domain):
37-
return all(c.is_discrete for c in domain.class_vars)
37+
return all(c.is_discrete for c in domain.class_vars), ''
3838

3939
def fit(self, X, Y, W):
4040
rows, class_vars = Y.shape

Orange/widgets/utils/owlearnerwidget.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,9 @@ def check_data(self):
246246
self.Error.sparse_not_supported.clear()
247247
if self.data is not None and self.learner is not None:
248248
self.Error.data_error.clear()
249-
if not self.learner.check_learner_adequacy(self.data.domain):
250-
self.Error.data_error(self.learner.learner_adequacy_err_msg)
249+
learner_is_adequate, err_msg = self.learner.check_learner_adequacy(self.data.domain)
250+
if not learner_is_adequate:
251+
self.Error.data_error(err_msg)
251252
elif not len(self.data):
252253
self.Error.data_error("Dataset is empty.")
253254
elif len(ut.unique(self.data.Y)) < 2:

0 commit comments

Comments
 (0)