Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Orange/modelling/ada_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from Orange.ensembles import (
SklAdaBoostClassificationLearner, SklAdaBoostRegressionLearner
)
from Orange.modelling import Fitter
from Orange.modelling import SklFitter

__all__ = ['SklAdaBoostLearner']


class SklAdaBoostLearner(Fitter):
class SklAdaBoostLearner(SklFitter):
__fits__ = {'classification': SklAdaBoostClassificationLearner,
'regression': SklAdaBoostRegressionLearner}

Expand Down
30 changes: 14 additions & 16 deletions Orange/modelling/base.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,9 @@
from Orange.base import Learner, Model
import numpy as np

from Orange.base import Learner, Model, SklLearner

class FitterMeta(type):
"""Ensure that each subclass of the `Fitter` class overrides the `__fits__`
attribute with a valid value."""
def __new__(mcs, name, bases, attrs):
# Check that a fitter implementation defines a valid `__fits__`
if any(cls.__name__ == 'Fitter' for cls in bases):
fits = attrs.get('__fits__')
assert isinstance(fits, dict), '__fits__ must be dict instance'
assert fits.get('classification') and fits.get('regression'), \
('`__fits__` property does not define classification '
'or regression learner. Use a simple learner if you don\'t '
'need the functionality provided by Fitter.')
return super().__new__(mcs, name, bases, attrs)


class Fitter(Learner, metaclass=FitterMeta):
class Fitter(Learner):
"""Handle multiple types of target variable with one learner.

Subclasses of this class serve as a sort of dispatcher. When subclassing,
Expand Down Expand Up @@ -118,3 +105,14 @@ def params(self):
def get_params(self, problem_type):
"""Access the specific learner params of a given learner."""
return self.get_learner(problem_type).params


class SklFitter(Fitter):
def _fit_model(self, data):
model = super()._fit_model(data)
model.used_vals = [np.unique(y) for y in data.Y[:, None].T]
if data.domain.has_discrete_class:
model.params = self.get_params(self.CLASSIFICATION)
else:
model.params = self.get_params(self.REGRESSION)
return model
4 changes: 2 additions & 2 deletions Orange/modelling/knn.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from Orange.classification import KNNLearner as KNNClassification
from Orange.modelling import Fitter
from Orange.modelling import SklFitter
from Orange.regression import KNNRegressionLearner

__all__ = ['KNNLearner']


class KNNLearner(Fitter):
class KNNLearner(SklFitter):
__fits__ = {'classification': KNNClassification,
'regression': KNNRegressionLearner}
4 changes: 2 additions & 2 deletions Orange/modelling/linear.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from Orange.classification.sgd import SGDClassificationLearner
from Orange.modelling import Fitter
from Orange.modelling import SklFitter
from Orange.regression import SGDRegressionLearner

__all__ = ['SGDLearner']


class SGDLearner(Fitter):
class SGDLearner(SklFitter):
name = 'sgd'

__fits__ = {'classification': SGDClassificationLearner,
Expand Down
4 changes: 2 additions & 2 deletions Orange/modelling/neural_network.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from Orange.classification import NNClassificationLearner
from Orange.modelling import Fitter
from Orange.modelling import SklFitter
from Orange.regression import NNRegressionLearner

__all__ = ['NNLearner']


class NNLearner(Fitter):
class NNLearner(SklFitter):
__fits__ = {'classification': NNClassificationLearner,
'regression': NNRegressionLearner}
4 changes: 2 additions & 2 deletions Orange/modelling/randomforest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from Orange.base import RandomForestModel
from Orange.classification import RandomForestLearner as RFClassification
from Orange.modelling import Fitter
from Orange.modelling import SklFitter
from Orange.regression import RandomForestRegressionLearner as RFRegression

__all__ = ['RandomForestLearner']


class RandomForestLearner(Fitter):
class RandomForestLearner(SklFitter):
name = 'random forest'

__fits__ = {'classification': RFClassification,
Expand Down
8 changes: 4 additions & 4 deletions Orange/modelling/svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@
LinearSVMLearner as LinearSVCLearner,
NuSVMLearner as NuSVCLearner,
)
from Orange.modelling import Fitter
from Orange.modelling import SklFitter
from Orange.regression import SVRLearner, LinearSVRLearner, NuSVRLearner

__all__ = ['SVMLearner', 'LinearSVMLearner', 'NuSVMLearner']


class SVMLearner(Fitter):
class SVMLearner(SklFitter):
__fits__ = {'classification': SVCLearner, 'regression': SVRLearner}


class LinearSVMLearner(Fitter):
class LinearSVMLearner(SklFitter):
__fits__ = {'classification': LinearSVCLearner, 'regression': LinearSVRLearner}


class NuSVMLearner(Fitter):
class NuSVMLearner(SklFitter):
__fits__ = {'classification': NuSVCLearner, 'regression': NuSVRLearner}
4 changes: 2 additions & 2 deletions Orange/modelling/tree.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from Orange.classification import SklTreeLearner
from Orange.classification import TreeLearner as ClassificationTreeLearner
from Orange.modelling import Fitter
from Orange.modelling import Fitter, SklFitter
from Orange.regression import TreeLearner as RegressionTreeLearner
from Orange.regression.tree import SklTreeRegressionLearner
from Orange.tree import TreeModel

__all__ = ['SklTreeLearner', 'TreeLearner']


class SklTreeLearner(Fitter):
class SklTreeLearner(SklFitter):
name = 'tree'

__fits__ = {'classification': SklTreeLearner,
Expand Down
19 changes: 0 additions & 19 deletions Orange/tests/test_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,6 @@ def setUpClass(cls):
cls.heart_disease = Table('heart_disease')
cls.housing = Table('housing')

def test_throws_if_fits_property_is_invalid(self):
"""The `__fits__` attribute must be an instance of `LearnerTypes`"""
with self.assertRaises(Exception):
class DummyFitter(Fitter):
name = 'dummy'
__fits__ = (DummyClassificationLearner, DummyRegressionLearner)

fitter = DummyFitter()
fitter(self.heart_disease)

def test_dispatches_to_correct_learner(self):
"""Based on the input data, it should dispatch the fitting process to
the appropriate learner"""
Expand Down Expand Up @@ -102,15 +92,6 @@ class DummyFitter(Fitter):
except TypeError:
self.fail('Fitter did not properly distribute params to learners')

def test_error_for_data_type_with_no_learner(self):
"""If we attempt to define a fitter which only handles one data type
it makes more sense to simply use a Learner."""
with self.assertRaises(AssertionError):
class DummyFitter(Fitter):
name = 'dummy'
__fits__ = {'classification': None,
'regression': DummyRegressionLearner}

def test_correctly_sets_preprocessors_on_learner(self):
"""Fitters have to be able to pass the `use_default_preprocessors` and
preprocessors down to individual learners"""
Expand Down