Skip to content

Commit b5d6f5a

Browse files
authored
Merge pull request #3673 from VesnaT/transform_enh
[FIX] Transform: Replace 'Preprocess' input with 'Template Data' input
2 parents 021604b + 7ae3285 commit b5d6f5a

File tree

9 files changed

+129
-197
lines changed

9 files changed

+129
-197
lines changed

Orange/preprocess/preprocess.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -536,18 +536,6 @@ def transform(var):
536536
return data.transform(domain)
537537

538538

539-
class ApplyDomain(Preprocess):
540-
def __init__(self, domain, name):
541-
self._domain = domain
542-
self._name = name
543-
544-
def __call__(self, data):
545-
return data.transform(self._domain)
546-
547-
def __str__(self):
548-
return self._name
549-
550-
551539
class PreprocessorList(Preprocess):
552540
"""
553541
Store a list of preprocessors and on call apply them to the dataset.

Orange/widgets/data/owtransform.py

Lines changed: 64 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,77 @@
1+
from typing import Optional
2+
13
import numpy as np
24

35
from Orange.data import Table, Domain
4-
from Orange.preprocess.preprocess import Preprocess, Discretize
56
from Orange.widgets import gui
7+
from Orange.widgets.report.report import describe_data
68
from Orange.widgets.settings import Setting
79
from Orange.widgets.utils.sql import check_sql_input
810
from Orange.widgets.utils.widgetpreview import WidgetPreview
911
from Orange.widgets.widget import OWWidget, Input, Output, Msg
1012

1113

1214
class OWTransform(OWWidget):
13-
name = "Transform"
14-
description = "Transform data table."
15+
name = "Apply Domain"
16+
description = "Applies template domain on data table."
1517
icon = "icons/Transform.svg"
1618
priority = 2110
17-
keywords = []
19+
keywords = ["transform"]
1820

1921
retain_all_data = Setting(False)
2022

2123
class Inputs:
2224
data = Input("Data", Table, default=True)
23-
preprocessor = Input("Preprocessor", Preprocess)
25+
template_data = Input("Template Data", Table)
2426

2527
class Outputs:
2628
transformed_data = Output("Transformed Data", Table)
2729

2830
class Error(OWWidget.Error):
29-
pp_error = Msg("An error occurred while transforming data.\n{}")
31+
error = Msg("An error occurred while transforming data.\n{}")
3032

3133
resizing_enabled = False
3234
want_main_area = False
3335

3436
def __init__(self):
3537
super().__init__()
36-
self.data = None
37-
self.preprocessor = None
38-
self.transformed_data = None
38+
self.data = None # type: Optional[Table]
39+
self.template_domain = None # type: Optional[Domain]
40+
self.transformed_info = describe_data(None) # type: OrderedDict
3941

4042
info_box = gui.widgetBox(self.controlArea, "Info")
4143
self.input_label = gui.widgetLabel(info_box, "")
42-
self.preprocessor_label = gui.widgetLabel(info_box, "")
44+
self.template_label = gui.widgetLabel(info_box, "")
4345
self.output_label = gui.widgetLabel(info_box, "")
4446
self.set_input_label_text()
45-
self.set_preprocessor_label_text()
47+
self.set_template_label_text()
4648

47-
self.retain_all_data_cb = gui.checkBox(
48-
self.controlArea, self, "retain_all_data", label="Retain all data",
49-
callback=self.apply
50-
)
49+
box = gui.widgetBox(self.controlArea, "Output")
50+
gui.checkBox(box, self, "retain_all_data", "Retain all data",
51+
callback=self.apply)
5152

5253
def set_input_label_text(self):
5354
text = "No data on input."
54-
if self.data is not None:
55+
if self.data:
5556
text = "Input data with {:,} instances and {:,} features.".format(
5657
len(self.data),
5758
len(self.data.domain.attributes))
5859
self.input_label.setText(text)
5960

60-
def set_preprocessor_label_text(self):
61-
text = "No preprocessor on input."
62-
if self.transformed_data is not None:
63-
text = "Preprocessor {} applied.".format(self.preprocessor)
64-
elif self.preprocessor is not None:
65-
text = "Preprocessor {} on input.".format(self.preprocessor)
66-
self.preprocessor_label.setText(text)
61+
def set_template_label_text(self):
    """Refresh the info-box line that describes the template input.

    Three states are reported: no template domain stored, a template
    domain stored together with non-empty input data, and a template
    domain stored without input data (the feature count is shown then).
    """
    if self.template_domain is None:
        message = "No template data on input."
    elif self.data:
        message = "Template domain applied."
    else:
        n_features = len(self.template_domain.attributes)
        message = "Template data includes {:,} features.".format(n_features)
    self.template_label.setText(message)
6769

68-
def set_output_label_text(self):
70+
def set_output_label_text(self, data):
6971
text = ""
70-
if self.transformed_data:
72+
if data:
7173
text = "Output data includes {:,} features.".format(
72-
len(self.transformed_data.domain.attributes))
74+
len(data.domain.attributes))
7375
self.output_label.setText(text)
7476

7577
@Inputs.data
@@ -78,56 +80,53 @@ def set_data(self, data):
7880
self.data = data
7981
self.set_input_label_text()
8082

81-
@Inputs.preprocessor
82-
def set_preprocessor(self, preprocessor):
83-
self.preprocessor = preprocessor
83+
@Inputs.template_data
@check_sql_input
def set_template_data(self, data):
    """Remember the domain of the template table.

    ``data`` is the table received on the "Template Data" input, or
    ``None`` when the input is removed. Only its domain is stored, in
    ``self.template_domain``.
    """
    # Test for None explicitly. ``data and data.domain`` evaluates to
    # ``data`` itself whenever ``data`` is falsy but not None (presumably
    # an empty Table, whose truth value follows its length -- confirm),
    # which would leave a Table in ``template_domain`` and slip past the
    # ``is not None`` checks used elsewhere in the widget.
    self.template_domain = None if data is None else data.domain
8487

8588
def handleNewSignals(self):
8689
self.apply()
8790

8891
def apply(self):
8992
self.clear_messages()
90-
self.transformed_data = None
91-
if self.data is not None and self.preprocessor is not None:
93+
transformed_data = None
94+
if self.data and self.template_domain is not None:
9295
try:
93-
self.transformed_data = self.preprocessor(self.data)
94-
except Exception as ex: # pylint: disable=broad-except
95-
self.Error.pp_error(ex)
96-
97-
if self.retain_all_data:
98-
self.Outputs.transformed_data.send(self.merge_data())
99-
else:
100-
self.Outputs.transformed_data.send(self.transformed_data)
101-
102-
self.set_preprocessor_label_text()
103-
self.set_output_label_text()
104-
105-
def merge_data(self):
106-
attributes = getattr(self.data.domain, 'attributes')
107-
cls_vars = getattr(self.data.domain, 'class_vars')
108-
metas_v = getattr(self.data.domain, 'metas')\
109-
+ getattr(self.transformed_data.domain, 'attributes')
110-
domain = Domain(attributes, cls_vars, metas_v)
111-
X = self.data.X
112-
Y = self.data.Y
113-
metas = np.hstack((self.data.metas, self.transformed_data.X))
114-
table = Table.from_numpy(domain, X, Y, metas)
115-
table.name = getattr(self.data, 'name', '')
116-
table.attributes = getattr(self.data, 'attributes', {})
117-
table.ids = self.data.ids
118-
return table
96+
transformed_data = self.data.transform(self.template_domain)
97+
except Exception as ex: # pylint: disable=broad-except
98+
self.Error.error(ex)
99+
100+
data = transformed_data
101+
if data and self.retain_all_data:
102+
data = self.merged_data(data)
103+
self.transformed_info = describe_data(data)
104+
self.Outputs.transformed_data.send(data)
105+
self.set_template_label_text()
106+
self.set_output_label_text(data)
107+
108+
def merged_data(self, t_data):
    """Return the input data with the transformed columns added as metas.

    The attributes and class variables of ``self.data`` are kept; its
    metas are extended with the attributes and metas of ``t_data``.

    ``t_data`` is the table obtained by transforming ``self.data`` into
    the template domain. The result is ``self.data`` transformed into the
    widened domain, with the trailing meta columns overwritten by the
    values of ``t_data.X`` and ``t_data.metas``.
    """
    domain = self.data.domain
    t_domain = t_data.domain
    metas = domain.metas + t_domain.attributes + t_domain.metas
    domain = Domain(domain.attributes, domain.class_vars, metas)
    data = self.data.transform(domain)
    added = np.hstack((t_data.X, t_data.metas))
    # Use a non-negative start index: with no added columns the original
    # ``-added.shape[1]`` is ``-0``, which selects *every* meta column and
    # makes the assignment fail to broadcast.
    start = data.metas.shape[1] - added.shape[1]
    data.metas[:, start:] = added
    return data
119117

120118
def send_report(self):
121-
if self.preprocessor is not None:
122-
self.report_items("Settings",
123-
(("Preprocessor", self.preprocessor),))
124-
if self.data is not None:
119+
if self.data:
125120
self.report_data("Data", self.data)
126-
if self.transformed_data is not None:
127-
self.report_data("Transformed data", self.transformed_data)
121+
if self.template_domain is not None:
122+
self.report_domain("Template data", self.template_domain)
123+
if self.transformed_info:
124+
self.report_items("Transformed data", self.transformed_info)
128125

129126

130127
if __name__ == "__main__": # pragma: no cover
128+
from Orange.preprocess import Discretize
129+
130+
table = Table("iris")
131131
WidgetPreview(OWTransform).run(
132-
set_data=Table("iris"),
133-
set_preprocessor=Discretize())
132+
set_data=table, set_template_data=Discretize()(table))
Lines changed: 60 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
# Test methods with long descriptive names can omit docstrings
22
# pylint: disable=missing-docstring
3+
from unittest.mock import Mock
4+
5+
from numpy import testing as npt
6+
37
from Orange.data import Table
4-
from Orange.preprocess import Discretize
5-
from Orange.preprocess.preprocess import Preprocess
8+
from Orange.preprocess import Discretize, Continuize
69
from Orange.widgets.data.owtransform import OWTransform
710
from Orange.widgets.tests.base import WidgetTest
811
from Orange.widgets.unsupervised.owpca import OWPCA
@@ -12,38 +15,39 @@ class TestOWTransform(WidgetTest):
1215
def setUp(self):
1316
self.widget = self.create_widget(OWTransform)
1417
self.data = Table("iris")
15-
self.preprocessor = Discretize()
18+
self.disc_data = Discretize()(self.data)
1619

1720
def test_output(self):
18-
# send data and preprocessor
19-
self.send_signal(self.widget.Inputs.data, self.data)
20-
self.send_signal(self.widget.Inputs.preprocessor, self.preprocessor)
21+
# send data and template data
22+
self.send_signal(self.widget.Inputs.data, self.data[::15])
23+
self.send_signal(self.widget.Inputs.template_data, self.disc_data)
2124
output = self.get_output(self.widget.Outputs.transformed_data)
22-
self.assertIsInstance(output, Table)
23-
self.assertEqual("Input data with 150 instances and 4 features.",
25+
self.assertTableEqual(output, self.disc_data[::15])
26+
self.assertEqual("Input data with 10 instances and 4 features.",
2427
self.widget.input_label.text())
25-
self.assertEqual("Preprocessor Discretize() applied.",
26-
self.widget.preprocessor_label.text())
28+
self.assertEqual("Template domain applied.",
29+
self.widget.template_label.text())
2730
self.assertEqual("Output data includes 4 features.",
2831
self.widget.output_label.text())
2932

30-
# remove preprocessor
31-
self.send_signal(self.widget.Inputs.preprocessor, None)
33+
# remove template data
34+
self.send_signal(self.widget.Inputs.template_data, None)
3235
output = self.get_output(self.widget.Outputs.transformed_data)
3336
self.assertIsNone(output)
34-
self.assertEqual("Input data with 150 instances and 4 features.",
37+
self.assertEqual("Input data with 10 instances and 4 features.",
3538
self.widget.input_label.text())
36-
self.assertEqual("No preprocessor on input.", self.widget.preprocessor_label.text())
39+
self.assertEqual("No template data on input.",
40+
self.widget.template_label.text())
3741
self.assertEqual("", self.widget.output_label.text())
3842

39-
# send preprocessor
40-
self.send_signal(self.widget.Inputs.preprocessor, self.preprocessor)
43+
# send template data
44+
self.send_signal(self.widget.Inputs.template_data, self.disc_data)
4145
output = self.get_output(self.widget.Outputs.transformed_data)
42-
self.assertIsInstance(output, Table)
43-
self.assertEqual("Input data with 150 instances and 4 features.",
46+
self.assertTableEqual(output, self.disc_data[::15])
47+
self.assertEqual("Input data with 10 instances and 4 features.",
4448
self.widget.input_label.text())
45-
self.assertEqual("Preprocessor Discretize() applied.",
46-
self.widget.preprocessor_label.text())
49+
self.assertEqual("Template domain applied.",
50+
self.widget.template_label.text())
4751
self.assertEqual("Output data includes 4 features.",
4852
self.widget.output_label.text())
4953

@@ -52,49 +56,63 @@ def test_output(self):
5256
output = self.get_output(self.widget.Outputs.transformed_data)
5357
self.assertIsNone(output)
5458
self.assertEqual("No data on input.", self.widget.input_label.text())
55-
self.assertEqual("Preprocessor Discretize() on input.",
56-
self.widget.preprocessor_label.text())
59+
self.assertEqual("Template data includes 4 features.",
60+
self.widget.template_label.text())
5761
self.assertEqual("", self.widget.output_label.text())
5862

59-
# remove preprocessor
60-
self.send_signal(self.widget.Inputs.preprocessor, None)
63+
# remove template data
64+
self.send_signal(self.widget.Inputs.template_data, None)
6165
self.assertEqual("No data on input.", self.widget.input_label.text())
62-
self.assertEqual("No preprocessor on input.",
63-
self.widget.preprocessor_label.text())
66+
self.assertEqual("No template data on input.",
67+
self.widget.template_label.text())
6468
self.assertEqual("", self.widget.output_label.text())
6569

66-
def test_input_pca_preprocessor(self):
70+
def assertTableEqual(self, table1, table2):
    """Assert that two tables share one domain and hold equal arrays.

    The domains are compared by identity; ``X``, ``Y`` and ``metas`` are
    then compared element-wise, in that order.
    """
    self.assertIs(table1.domain, table2.domain)
    for part in ("X", "Y", "metas"):
        npt.assert_array_equal(getattr(table1, part), getattr(table2, part))
75+
76+
def test_input_pca_output(self):
6777
owpca = self.create_widget(OWPCA)
6878
self.send_signal(owpca.Inputs.data, self.data, widget=owpca)
6979
owpca.components_spin.setValue(2)
70-
pp = self.get_output(owpca.Outputs.preprocessor, widget=owpca)
71-
self.assertIsNotNone(pp, Preprocess)
80+
pca_out = self.get_output(owpca.Outputs.transformed_data, widget=owpca)
7281

73-
self.send_signal(self.widget.Inputs.data, self.data)
74-
self.send_signal(self.widget.Inputs.preprocessor, pp)
82+
self.send_signal(self.widget.Inputs.data, self.data[::10])
83+
self.send_signal(self.widget.Inputs.template_data, pca_out)
7584
output = self.get_output(self.widget.Outputs.transformed_data)
76-
self.assertIsInstance(output, Table)
77-
self.assertEqual(output.X.shape, (len(self.data), 2))
85+
npt.assert_array_equal(pca_out.X[::10], output.X)
7886

79-
# test retain data functionality
80-
self.widget.retain_all_data = True
81-
self.widget.apply()
87+
def test_retain_all_data(self):
88+
data = Table("zoo")
89+
cont_data = Continuize()(data)
90+
self.send_signal(self.widget.Inputs.data, data)
91+
self.send_signal(self.widget.Inputs.template_data, cont_data)
92+
self.widget.controls.retain_all_data.click()
8293
output = self.get_output(self.widget.Outputs.transformed_data)
8394
self.assertIsInstance(output, Table)
84-
self.assertEqual(output.X.shape, (len(self.data), 4))
85-
self.assertEqual(output.metas.shape, (len(self.data), 2))
95+
self.assertEqual(output.X.shape, (len(data), 16))
96+
self.assertEqual(output.metas.shape, (len(data), 38))
8697

8798
def test_error_transforming(self):
88-
self.send_signal(self.widget.Inputs.data, self.data)
89-
self.send_signal(self.widget.Inputs.preprocessor, Preprocess())
90-
self.assertTrue(self.widget.Error.pp_error.is_shown())
99+
data = self.data[::10]
100+
data.transform = Mock(side_effect=Exception())
101+
self.send_signal(self.widget.Inputs.data, data)
102+
self.send_signal(self.widget.Inputs.template_data, self.disc_data)
103+
self.assertTrue(self.widget.Error.error.is_shown())
91104
output = self.get_output(self.widget.Outputs.transformed_data)
92105
self.assertIsNone(output)
93106
self.send_signal(self.widget.Inputs.data, None)
94-
self.assertFalse(self.widget.Error.pp_error.is_shown())
107+
self.assertFalse(self.widget.Error.error.is_shown())
95108

96109
def test_send_report(self):
    """Click the report button with data on input and after removing it."""
    for payload in (self.data, None):
        self.send_signal(self.widget.Inputs.data, payload)
        self.widget.report_button.click()
114+
115+
116+
if __name__ == "__main__":
117+
import unittest
118+
unittest.main()

0 commit comments

Comments
 (0)