Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Orange/widgets/evaluate/owpredictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,11 @@ def _update_predictions_model(self):
for p in slots:
values, prob = p.results
if self.class_var.is_discrete:
# if values were added to class_var between building the
# model and predicting, add zeros for new class values,
# which are always at the end
prob = numpy.c_[prob,
numpy.zeros((prob.shape[0], len(class_var.values) - prob.shape[1]))]
values = [Value(class_var, v) for v in values]
results.append((values, prob))
results = list(zip(*(zip(*res) for res in results)))
Expand Down
71 changes: 61 additions & 10 deletions Orange/widgets/evaluate/tests/test_owpredictions.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Tests for OWPredictions"""

import io
import numpy as np

from Orange.data.io import TabReader
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.evaluate.owpredictions import OWPredictions

from Orange.data import Table, Domain
from Orange.classification import MajorityLearner
from Orange.data import Table, Domain, Variable
from Orange.modelling import ConstantLearner, TreeLearner
from Orange.evaluation import Results
from Orange.widgets.tests.utils import excepthook_catch


class TestOWPredictions(WidgetTest):
Expand All @@ -27,7 +29,7 @@ def test_nan_target_input(self):
yvec, _ = data.get_column_view(data.domain.class_var)
nanmask = np.isnan(yvec)
self.send_signal("Data", data)
self.send_signal("Predictors", MajorityLearner()(data), 1)
self.send_signal("Predictors", ConstantLearner()(data), 1)
pred = self.get_output("Predictions", )
self.assertIsInstance(pred, Table)
np.testing.assert_array_equal(
Expand All @@ -50,8 +52,8 @@ def test_mismatching_targets(self):
error = self.widget.Error

titanic = Table("titanic")
majority_titanic = MajorityLearner()(titanic)
majority_iris = MajorityLearner()(self.iris)
majority_titanic = ConstantLearner()(titanic)
majority_iris = ConstantLearner()(self.iris)

self.send_signal("Data", self.iris)
self.send_signal("Predictors", majority_iris, 1)
Expand Down Expand Up @@ -101,14 +103,14 @@ def test_no_class_on_test(self):
error = self.widget.Error

titanic = Table("titanic")
majority_titanic = MajorityLearner()(titanic)
majority_iris = MajorityLearner()(self.iris)
majority_titanic = ConstantLearner()(titanic)
majority_iris = ConstantLearner()(self.iris)

no_class = Table(Domain(titanic.domain.attributes, None), titanic)
self.send_signal("Predictors", majority_titanic, 1)
self.send_signal("Data", no_class)
out = self.get_output("Predictions")
np.testing.assert_allclose(out.get_column_view("majority")[0], 0)
np.testing.assert_allclose(out.get_column_view("constant")[0], 0)

self.send_signal("Predictors", majority_iris, 2)
self.assertTrue(error.predictors_target_mismatch.is_shown())
Expand All @@ -118,4 +120,53 @@ def test_no_class_on_test(self):
self.send_signal("Predictors", None, 2)
self.send_signal("Data", titanic)
out = self.get_output("Predictions")
np.testing.assert_allclose(out.get_column_view("majority")[0], 0)
np.testing.assert_allclose(out.get_column_view("constant")[0], 0)

def test_bad_data(self):
"""
Firstly it creates predictions with TreeLearner. Then sends predictions and
different data with different domain to Predictions widget. Those different
data and domain are similar to original data and domain but they have three
different target values instead of two.
GH-2129
"""
Variable._clear_all_caches()

filestr1 = """\
age\tsex\tsurvived
d\td\td
\t\tclass
adult\tmale\tyes
adult\tfemale\tno
child\tmale\tyes
child\tfemale\tyes
"""
file1 = io.StringIO(filestr1)
table = TabReader(file1).read()
learner = TreeLearner()
tree = learner(table)

filestr2 = """\
age\tsex\tsurvived
d\td\td
\t\tclass
adult\tmale\tyes
adult\tfemale\tno
child\tmale\tyes
child\tfemale\tunknown
"""
file2 = io.StringIO(filestr2)
bad_table = TabReader(file2).read()

self.send_signal("Predictors", tree, 1)

with excepthook_catch():
self.send_signal("Data", bad_table)

Variable._clear_all_caches() # so that test excepting standard titanic work

def test_continuous_class(self):
data = Table("housing")
cl_data = ConstantLearner()(data)
self.send_signal("Predictors", cl_data, 1)
self.send_signal("Data", data)