|
1 | | -from Orange.base import Learner, Model |
| 1 | +import numpy as np |
2 | 2 |
|
| 3 | +from Orange.base import Learner, Model, SklLearner |
3 | 4 |
|
4 | | -class FitterMeta(type): |
5 | | - """Ensure that each subclass of the `Fitter` class overrides the `__fits__` |
6 | | - attribute with a valid value.""" |
7 | | - def __new__(mcs, name, bases, attrs): |
8 | | - # Check that a fitter implementation defines a valid `__fits__` |
9 | | - if any(cls.__name__ == 'Fitter' for cls in bases): |
10 | | - fits = attrs.get('__fits__') |
11 | | - assert isinstance(fits, dict), '__fits__ must be dict instance' |
12 | | - assert fits.get('classification') and fits.get('regression'), \ |
13 | | - ('`__fits__` property does not define classification ' |
14 | | - 'or regression learner. Use a simple learner if you don\'t ' |
15 | | - 'need the functionality provided by Fitter.') |
16 | | - return super().__new__(mcs, name, bases, attrs) |
17 | 5 |
|
18 | | - |
19 | | -class Fitter(Learner, metaclass=FitterMeta): |
| 6 | +class Fitter(Learner): |
20 | 7 | """Handle multiple types of target variable with one learner. |
21 | 8 |
|
22 | 9 | Subclasses of this class serve as a sort of dispatcher. When subclassing, |
@@ -119,3 +106,14 @@ def params(self): |
119 | 106 | def get_params(self, problem_type): |
120 | 107 | """Access the specific learner params of a given learner.""" |
121 | 108 | return self.get_learner(problem_type).params |
| 109 | + |
| 110 | + |
class SklFitter(Fitter):
    """Fitter specialization for scikit-learn based learners.

    After delegating the actual fitting to the parent class, the
    resulting model is annotated with the distinct target values seen
    in the training data (``used_vals``) and with the parameters of the
    learner that was dispatched for the data's problem type (``params``).
    """

    def _fit_model(self, data):
        """Fit via the parent fitter, then attach sklearn-style metadata."""
        model = super()._fit_model(data)
        # One entry per target column: the sorted distinct values of y.
        model.used_vals = [np.unique(col) for col in data.Y[:, None].T]
        # Pick the param set matching the dispatched problem type.
        problem_type = (self.CLASSIFICATION
                        if data.domain.has_discrete_class
                        else self.REGRESSION)
        model.params = self.get_params(problem_type)
        return model
0 commit comments