Skip to content

Commit 99ca084

Browse files
authored
Merge pull request #2138 from pavlin-policar/fitter-sklearn-quick
[FIX] Fitter: Fix used_vals and params not being set
2 parents 407db85 + efab929 commit 99ca084

File tree

9 files changed

+30
-51
lines changed

9 files changed

+30
-51
lines changed

Orange/modelling/ada_boost.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
from Orange.ensembles import (
33
SklAdaBoostClassificationLearner, SklAdaBoostRegressionLearner
44
)
5-
from Orange.modelling import Fitter
5+
from Orange.modelling import SklFitter
66

77
__all__ = ['SklAdaBoostLearner']
88

99

10-
class SklAdaBoostLearner(Fitter):
10+
class SklAdaBoostLearner(SklFitter):
1111
__fits__ = {'classification': SklAdaBoostClassificationLearner,
1212
'regression': SklAdaBoostRegressionLearner}
1313

Orange/modelling/base.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,9 @@
1-
from Orange.base import Learner, Model
1+
import numpy as np
22

3+
from Orange.base import Learner, Model, SklLearner
34

4-
class FitterMeta(type):
5-
"""Ensure that each subclass of the `Fitter` class overrides the `__fits__`
6-
attribute with a valid value."""
7-
def __new__(mcs, name, bases, attrs):
8-
# Check that a fitter implementation defines a valid `__fits__`
9-
if any(cls.__name__ == 'Fitter' for cls in bases):
10-
fits = attrs.get('__fits__')
11-
assert isinstance(fits, dict), '__fits__ must be dict instance'
12-
assert fits.get('classification') and fits.get('regression'), \
13-
('`__fits__` property does not define classification '
14-
'or regression learner. Use a simple learner if you don\'t '
15-
'need the functionality provided by Fitter.')
16-
return super().__new__(mcs, name, bases, attrs)
175

18-
19-
class Fitter(Learner, metaclass=FitterMeta):
6+
class Fitter(Learner):
207
"""Handle multiple types of target variable with one learner.
218
229
Subclasses of this class serve as a sort of dispatcher. When subclassing,
@@ -119,3 +106,14 @@ def params(self):
119106
def get_params(self, problem_type):
120107
"""Access the specific learner params of a given learner."""
121108
return self.get_learner(problem_type).params
109+
110+
111+
class SklFitter(Fitter):
112+
def _fit_model(self, data):
113+
model = super()._fit_model(data)
114+
model.used_vals = [np.unique(y) for y in data.Y[:, None].T]
115+
if data.domain.has_discrete_class:
116+
model.params = self.get_params(self.CLASSIFICATION)
117+
else:
118+
model.params = self.get_params(self.REGRESSION)
119+
return model

Orange/modelling/knn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from Orange.classification import KNNLearner as KNNClassification
2-
from Orange.modelling import Fitter
2+
from Orange.modelling import SklFitter
33
from Orange.regression import KNNRegressionLearner
44

55
__all__ = ['KNNLearner']
66

77

8-
class KNNLearner(Fitter):
8+
class KNNLearner(SklFitter):
99
__fits__ = {'classification': KNNClassification,
1010
'regression': KNNRegressionLearner}

Orange/modelling/linear.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from Orange.classification.sgd import SGDClassificationLearner
2-
from Orange.modelling import Fitter
2+
from Orange.modelling import SklFitter
33
from Orange.regression import SGDRegressionLearner
44

55
__all__ = ['SGDLearner']
66

77

8-
class SGDLearner(Fitter):
8+
class SGDLearner(SklFitter):
99
name = 'sgd'
1010

1111
__fits__ = {'classification': SGDClassificationLearner,

Orange/modelling/neural_network.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from Orange.classification import NNClassificationLearner
2-
from Orange.modelling import Fitter
2+
from Orange.modelling import SklFitter
33
from Orange.regression import NNRegressionLearner
44

55
__all__ = ['NNLearner']
66

77

8-
class NNLearner(Fitter):
8+
class NNLearner(SklFitter):
99
__fits__ = {'classification': NNClassificationLearner,
1010
'regression': NNRegressionLearner}

Orange/modelling/randomforest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from Orange.base import RandomForestModel
22
from Orange.classification import RandomForestLearner as RFClassification
3-
from Orange.modelling import Fitter
3+
from Orange.modelling import SklFitter
44
from Orange.regression import RandomForestRegressionLearner as RFRegression
55

66
__all__ = ['RandomForestLearner']
77

88

9-
class RandomForestLearner(Fitter):
9+
class RandomForestLearner(SklFitter):
1010
name = 'random forest'
1111

1212
__fits__ = {'classification': RFClassification,

Orange/modelling/svm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,19 @@
33
LinearSVMLearner as LinearSVCLearner,
44
NuSVMLearner as NuSVCLearner,
55
)
6-
from Orange.modelling import Fitter
6+
from Orange.modelling import SklFitter
77
from Orange.regression import SVRLearner, LinearSVRLearner, NuSVRLearner
88

99
__all__ = ['SVMLearner', 'LinearSVMLearner', 'NuSVMLearner']
1010

1111

12-
class SVMLearner(Fitter):
12+
class SVMLearner(SklFitter):
1313
__fits__ = {'classification': SVCLearner, 'regression': SVRLearner}
1414

1515

16-
class LinearSVMLearner(Fitter):
16+
class LinearSVMLearner(SklFitter):
1717
__fits__ = {'classification': LinearSVCLearner, 'regression': LinearSVRLearner}
1818

1919

20-
class NuSVMLearner(Fitter):
20+
class NuSVMLearner(SklFitter):
2121
__fits__ = {'classification': NuSVCLearner, 'regression': NuSVRLearner}

Orange/modelling/tree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from Orange.classification import SklTreeLearner
22
from Orange.classification import TreeLearner as ClassificationTreeLearner
3-
from Orange.modelling import Fitter
3+
from Orange.modelling import Fitter, SklFitter
44
from Orange.regression import TreeLearner as RegressionTreeLearner
55
from Orange.regression.tree import SklTreeRegressionLearner
66
from Orange.tree import TreeModel
77

88
__all__ = ['SklTreeLearner', 'TreeLearner']
99

1010

11-
class SklTreeLearner(Fitter):
11+
class SklTreeLearner(SklFitter):
1212
name = 'tree'
1313

1414
__fits__ = {'classification': SklTreeLearner,

Orange/tests/test_fitter.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,6 @@ def setUpClass(cls):
2929
cls.heart_disease = Table('heart_disease')
3030
cls.housing = Table('housing')
3131

32-
def test_throws_if_fits_property_is_invalid(self):
33-
"""The `__fits__` attribute must be an instance of `LearnerTypes`"""
34-
with self.assertRaises(Exception):
35-
class DummyFitter(Fitter):
36-
name = 'dummy'
37-
__fits__ = (DummyClassificationLearner, DummyRegressionLearner)
38-
39-
fitter = DummyFitter()
40-
fitter(self.heart_disease)
41-
4232
def test_dispatches_to_correct_learner(self):
4333
"""Based on the input data, it should dispatch the fitting process to
4434
the appropriate learner"""
@@ -102,15 +92,6 @@ class DummyFitter(Fitter):
10292
except TypeError:
10393
self.fail('Fitter did not properly distribute params to learners')
10494

105-
def test_error_for_data_type_with_no_learner(self):
106-
"""If we attempt to define a fitter which only handles one data type
107-
it makes more sense to simply use a Learner."""
108-
with self.assertRaises(AssertionError):
109-
class DummyFitter(Fitter):
110-
name = 'dummy'
111-
__fits__ = {'classification': None,
112-
'regression': DummyRegressionLearner}
113-
11495
def test_correctly_sets_preprocessors_on_learner(self):
11596
"""Fitters have to be able to pass the `use_default_preprocessors` and
11697
preprocessors down to individual learners"""

0 commit comments

Comments
 (0)