Skip to content

Commit 04cf2e3

Browse files
committed
add multi_target_input tests
1 parent fedb67e commit 04cf2e3

File tree

2 files changed

+65
-6
lines changed

2 files changed

+65
-6
lines changed

Orange/widgets/evaluate/tests/test_owpredictions.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ def set_input(data, model):
188188
(self.widget.Inputs.data, data),
189189
(self.widget.Inputs.predictors, model)
190190
])
191+
191192
iris = self.iris
192193
learner = ConstantLearner()
193194
heart_disease = Table("heart_disease")
@@ -253,6 +254,7 @@ def test_sort_predictions(self):
253254
"""
254255
Test whether sorting of probabilities by FilterSortProxy is correct.
255256
"""
257+
256258
def get_items_order(model):
257259
return model.mapToSourceRows(np.arange(model.rowCount()))
258260

@@ -878,6 +880,27 @@ def test_change_target(self):
878880
self.assertEqual(float(table.model.data(table.model.index(0, 3))),
879881
idx)
880882

883+
def test_multi_target_input(self):
884+
widget = self.widget
885+
886+
domain = Domain([ContinuousVariable('var1')],
887+
class_vars=[
888+
ContinuousVariable('c1'),
889+
DiscreteVariable('c2', values=('no', 'yes'))
890+
])
891+
data = Table.from_list(domain, [[1, 5, 0], [2, 10, 1]])
892+
893+
mock_model = Mock(spec=Model, return_value=np.asarray([0.2, 0.1]))
894+
mock_model.name = 'Mockery'
895+
mock_model.domain = domain
896+
mock_learner = Mock(return_value=mock_model)
897+
model = mock_learner(data)
898+
899+
self.send_signal(widget.Inputs.data, data)
900+
self.send_signal(widget.Inputs.predictors, model, 1)
901+
pred = self.get_output(widget.Outputs.predictions)
902+
self.assertIsInstance(pred, Table)
903+
881904
def test_report(self):
882905
widget = self.widget
883906

@@ -1022,7 +1045,6 @@ def assert_called(exp_selected, exp_deselected):
10221045
self.assertEqual(list(selected), exp_selected)
10231046
self.assertEqual(list(deselected), exp_deselected)
10241047

1025-
10261048
store.model.setSortIndices([4, 0, 1, 2, 3])
10271049
store.select_rows({3, 4}, QItemSelectionModel.Select)
10281050
assert_called([4, 0], [])
@@ -1132,7 +1154,7 @@ def setUpClass(cls) -> None:
11321154
cls.probs = [np.array([[80, 10, 10],
11331155
[30, 70, 0],
11341156
[15, 80, 5],
1135-
[0, 10, 90],
1157+
[0, 10, 90],
11361158
[55, 40, 5]]) / 100,
11371159
np.array([[80, 0, 20],
11381160
[90, 5, 5],

Orange/widgets/evaluate/tests/test_owtestandscore.py

+41-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from Orange.evaluation import Results, TestOnTestData, scoring
1717
from Orange.evaluation.scoring import ClassificationScore, RegressionScore, \
1818
Score
19-
from Orange.base import Learner
19+
from Orange.base import Learner, Model
2020
from Orange.modelling import ConstantLearner
2121
from Orange.regression import MeanLearner
2222
from Orange.widgets.evaluate.owtestandscore import (
@@ -178,7 +178,7 @@ def test_one_class_value(self):
178178
table = Table.from_list(
179179
Domain(
180180
[ContinuousVariable("a"), ContinuousVariable("b")],
181-
[DiscreteVariable("c", values=("y", ))]),
181+
[DiscreteVariable("c", values=("y",))]),
182182
list(zip(
183183
[42.48, 16.84, 15.23, 23.8],
184184
[1., 2., 3., 4.],
@@ -192,6 +192,7 @@ def test_one_class_value(self):
192192

193193
def test_data_errors(self):
194194
""" Test all data_errors """
195+
195196
def assertErrorShown(data, is_shown, message):
196197
self.send_signal("Data", data)
197198
self.assertEqual(is_shown, self.widget.Error.train_data_error.is_shown())
@@ -378,7 +379,7 @@ def test_scores_log_reg_overfitted(self):
378379
self.assertTupleEqual(self._test_scores(
379380
table, table, LogisticRegressionLearner(),
380381
OWTestAndScore.TestOnTest, None),
381-
(1, 1, 1, 1, 1))
382+
(1, 1, 1, 1, 1))
382383

383384
def test_scores_log_reg_bad(self):
384385
table_train = Table.from_list(
@@ -393,7 +394,7 @@ def test_scores_log_reg_bad(self):
393394
self.assertTupleEqual(self._test_scores(
394395
table_train, table_test, LogisticRegressionLearner(),
395396
OWTestAndScore.TestOnTest, None),
396-
(0, 0, 0, 0, 0))
397+
(0, 0, 0, 0, 0))
397398

398399
def test_scores_log_reg_bad2(self):
399400
table_train = Table.from_list(
@@ -724,6 +725,42 @@ def test_copy_to_clipboard(self):
724725
for i in (0, 3, 4, 5, 6, 7)]) + "\r\n"
725726
self.assertEqual(clipboard_text, view_text)
726727

728+
def test_multi_target_input(self):
729+
class NewScorer(Score):
730+
class_types = (
731+
ContinuousVariable,
732+
DiscreteVariable,
733+
)
734+
735+
@staticmethod
736+
def is_compatible(domain: Domain) -> bool:
737+
return True
738+
739+
def compute_score(self, results):
740+
return [0.75]
741+
742+
domain = Domain([ContinuousVariable('var1')],
743+
class_vars=[
744+
ContinuousVariable('c1'),
745+
DiscreteVariable('c2', values=('no', 'yes'))
746+
])
747+
data = Table.from_list(domain, [[1, 5, 0], [2, 10, 1], [2, 10, 1]])
748+
749+
mock_model = Mock(spec=Model, return_value=np.asarray([[0.2, 0.1, 0.2]]))
750+
mock_model.name = 'Mockery'
751+
mock_model.domain = domain
752+
mock_learner = Mock(spec=Learner, return_value=mock_model)
753+
mock_learner.name = 'Mockery'
754+
755+
self.widget.resampling = OWTestAndScore.TestOnTrain
756+
self.send_signal(self.widget.Inputs.train_data, data)
757+
self.send_signal(self.widget.Inputs.learner, MajorityLearner(), 0)
758+
self.send_signal(self.widget.Inputs.learner, mock_learner, 1)
759+
_ = self.get_output(self.widget.Outputs.evaluations_results, wait=5000)
760+
self.assertTrue(len(self.widget.scorers) == 1)
761+
self.assertTrue(NewScorer in self.widget.scorers)
762+
self.assertTrue(len(self.widget._successful_slots()) == 1)
763+
727764

728765
class TestHelpers(unittest.TestCase):
729766
def test_results_one_vs_rest(self):

0 commit comments

Comments
 (0)