Skip to content

Commit 46961b4

Browse files
authored
Merge pull request #1499 from VesnaT/svm_tests
[ENH] Upgrade OWSvm unittests
2 parents b0793ac + f882602 commit 46961b4

File tree

5 files changed

+194
-157
lines changed

5 files changed

+194
-157
lines changed

Orange/widgets/classify/owsvmclassification.py

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@ class OWBaseSVM(OWBaseLearner):
1717
#: kernel degree
1818
degree = settings.Setting(3)
1919
#: gamma
20-
gamma = settings.Setting(1.0)
20+
gamma = settings.Setting(0.0)
2121
#: coef0 (adative constant)
2222
coef0 = settings.Setting(0.0)
2323

2424
#: numerical tolerance
2525
tol = settings.Setting(0.001)
2626

27+
_default_gamma = "auto"
2728
kernels = (("Linear", "x⋅y"),
2829
("Polynomial", "(g x⋅y + c)<sup>d</sup>"),
2930
("RBF", "exp(-g|x-y|²)"),
@@ -50,6 +51,7 @@ def _add_kernel_box(self):
5051
inbox = gui.vBox(spbox)
5152
gamma = gui.doubleSpin(
5253
inbox, self, "gamma", 0.0, 10.0, 0.01, label=" g: ", **common)
54+
gamma.setSpecialValueText(self._default_gamma)
5355
coef0 = gui.doubleSpin(
5456
inbox, self, "coef0", 0.0, 10.0, 0.01, label=" c: ", **common)
5557
degree = gui.doubleSpin(
@@ -66,7 +68,7 @@ def _add_kernel_box(self):
6668
def _add_optimization_box(self):
6769
self.optimization_box = gui.vBox(
6870
self.controlArea, "Optimization Parameters")
69-
gui.doubleSpin(
71+
self.tol_spin = gui.doubleSpin(
7072
self.optimization_box, self, "tol", 1e-6, 1.0, 1e-5,
7173
label="Numerical tolerance:",
7274
decimals=6, alignment=Qt.AlignRight, controlWidth=100,
@@ -94,17 +96,18 @@ def _on_kernel_changed(self):
9496
self.settings_changed()
9597

9698
def _report_kernel_parameters(self, items):
99+
gamma = self.gamma or self._default_gamma
97100
if self.kernel_type == 0:
98101
items["Kernel"] = "Linear"
99102
elif self.kernel_type == 1:
100103
items["Kernel"] = \
101104
"Polynomial, ({g:.4} x⋅y + {c:.4})<sup>{d}</sup>".format(
102-
g=self.gamma, c=self.coef0, d=self.degree)
105+
g=gamma, c=self.coef0, d=self.degree)
103106
elif self.kernel_type == 2:
104-
items["Kernel"] = "RBF, exp(-{:.4}|x-y|²)".format(self.gamma)
107+
items["Kernel"] = "RBF, exp(-{:.4}|x-y|²)".format(gamma)
105108
else:
106109
items["Kernel"] = "Sigmoid, tanh({g:.4} x⋅y + {c:.4})".format(
107-
g=self.gamma, c=self.coef0)
110+
g=gamma, c=self.coef0)
108111

109112
def update_model(self):
110113
super().update_model()
@@ -126,7 +129,7 @@ class OWSVMClassification(OWBaseSVM):
126129

127130
outputs = [("Support vectors", Table)]
128131

129-
# 0: c_svc, 1: nu_svc
132+
C_SVC, Nu_SVC = 0, 1
130133
svmtype = settings.Setting(0)
131134
C = settings.Setting(1.0)
132135
nu = settings.Setting(0.5)
@@ -141,53 +144,51 @@ def _add_type_box(self):
141144
self.controlArea, self, "svmtype", [], box="SVM Type",
142145
orientation=form, callback=self.settings_changed)
143146

144-
form.addWidget(gui.appendRadioButton(box, "C-SVM", addToLayout=False),
145-
0, 0, Qt.AlignLeft)
146-
form.addWidget(QtGui.QLabel("Cost (C):"),
147-
0, 1, Qt.AlignRight)
148-
form.addWidget(gui.doubleSpin(box, self, "C", 1e-3, 1000.0, 0.1,
149-
decimals=3, alignment=Qt.AlignRight,
150-
controlWidth=80, addToLayout=False,
151-
callback=self.settings_changed),
152-
0, 2)
153-
154-
form.addWidget(gui.appendRadioButton(box, "ν-SVM", addToLayout=False),
155-
1, 0, Qt.AlignLeft)
156-
form.addWidget(QtGui.QLabel("Complexity (ν):"),
157-
1, 1, Qt.AlignRight)
158-
form.addWidget(gui.doubleSpin(box, self, "nu", 0.05, 1.0, 0.05,
159-
decimals=2, alignment=Qt.AlignRight,
160-
controlWidth=80, addToLayout=False,
161-
callback=self.settings_changed),
162-
1, 2)
147+
self.c_radio = gui.appendRadioButton(box, "C-SVM", addToLayout=False)
148+
self.nu_radio = gui.appendRadioButton(box, "ν-SVM", addToLayout=False)
149+
self.c_spin = gui.doubleSpin(
150+
box, self, "C", 1e-3, 1000.0, 0.1, decimals=3,
151+
alignment=Qt.AlignRight, controlWidth=80, addToLayout=False,
152+
callback=self.settings_changed)
153+
self.nu_spin = gui.doubleSpin(
154+
box, self, "nu", 0.05, 1.0, 0.05, decimals=2,
155+
alignment=Qt.AlignRight, controlWidth=80, addToLayout=False,
156+
callback=self.settings_changed)
157+
form.addWidget(self.c_radio, 0, 0, Qt.AlignLeft)
158+
form.addWidget(QtGui.QLabel("Cost (C):"), 0, 1, Qt.AlignRight)
159+
form.addWidget(self.c_spin, 0, 2)
160+
form.addWidget(self.nu_radio, 1, 0, Qt.AlignLeft)
161+
form.addWidget(QtGui.QLabel("Complexity (ν):"), 1, 1, Qt.AlignRight)
162+
form.addWidget(self.nu_spin, 1, 2)
163163

164164
def _add_optimization_box(self):
165165
super()._add_optimization_box()
166-
gui.spin(self.optimization_box, self, "max_iter", 50, 1e6, 50,
167-
label="Iteration limit:", checked="limit_iter",
168-
alignment=Qt.AlignRight, controlWidth=100,
169-
callback=self.settings_changed)
166+
self.max_iter_spin = gui.spin(
167+
self.optimization_box, self, "max_iter", 50, 1e6, 50,
168+
label="Iteration limit:", checked="limit_iter",
169+
alignment=Qt.AlignRight, controlWidth=100,
170+
callback=self.settings_changed)
170171

171172
def create_learner(self):
172173
kernel = ["linear", "poly", "rbf", "sigmoid"][self.kernel_type]
173174
common_args = dict(
174175
kernel=kernel,
175176
degree=self.degree,
176-
gamma=self.gamma,
177+
gamma=self.gamma or self._default_gamma,
177178
coef0=self.coef0,
178179
tol=self.tol,
179180
max_iter=self.max_iter if self.limit_iter else -1,
180181
probability=True,
181182
preprocessors=self.preprocessors
182183
)
183-
if self.svmtype == 0:
184+
if self.svmtype == OWSVMClassification.C_SVC:
184185
return SVMLearner(C=self.C, **common_args)
185186
else:
186187
return NuSVMLearner(nu=self.nu, **common_args)
187188

188189
def get_learner_parameters(self):
189190
items = OrderedDict()
190-
if self.svmtype == 0:
191+
if self.svmtype == OWSVMClassification.C_SVC:
191192
items["SVM type"] = "C-SVM, C={}".format(self.C)
192193
else:
193194
items["SVM type"] = "ν-SVM, ν={}".format(self.nu)
Lines changed: 69 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,107 +1,82 @@
11
# Test methods with long descriptive names can omit docstrings
22
# pylint: disable=missing-docstring
3-
from PyQt4 import QtGui
4-
5-
from Orange.data import Table
63
from Orange.widgets.classify.owsvmclassification import OWSVMClassification
7-
from Orange.widgets.tests.base import WidgetTest
4+
from Orange.widgets.tests.base import (WidgetTest, DefaultParameterMapping,
5+
ParameterMapping, WidgetLearnerTestMixin)
86

97

10-
class TestOWSVMClassification(WidgetTest):
8+
class TestOWSVMClassification(WidgetTest, WidgetLearnerTestMixin):
119
def setUp(self):
12-
self.widget = self.create_widget(OWSVMClassification)
13-
self.widget.spin_boxes = self.widget.findChildren(QtGui.QDoubleSpinBox)
14-
# max iter spin
15-
self.widget.spin_boxes.append(self.widget.findChildren(QtGui.QSpinBox)[0])
16-
# max iter checkbox
17-
self.widget.max_iter_check_box = self.widget.findChildren(QtGui.QCheckBox)[0]
18-
self.spin_boxes = self.widget.spin_boxes
19-
self.event_data = None
20-
21-
def test_kernel_equation_run(self):
22-
""" Check if right text is written for specific kernel """
23-
for i in range(0, 4):
24-
if self.widget.kernel_box.buttons[i].isChecked():
25-
self.assertEqual(self.widget.kernel_eq, self.widget.kernels[i][1])
26-
27-
def test_kernel_equation(self):
28-
""" Check if right text is written for specific kernel after click """
29-
for index in range(0, 4):
30-
self.widget.kernel_box.buttons[index].click()
31-
self.assertEqual(self.widget.kernel_eq, self.widget.kernels[index][1])
32-
33-
def test_kernel_display_run(self):
34-
""" Check if right spinner box for selected kernel are visible after widget start """
35-
for button_pos, value in ((0, [False, False, False]),
36-
(1, [True, True, True]),
37-
(2, [True, False, False]),
38-
(3, [True, True, False])):
39-
if self.widget.kernel_box.buttons[button_pos].isChecked():
40-
self.assertEqual([not self.spin_boxes[i].box.isHidden() for i in range(2, 5)],
41-
value)
42-
break
10+
self.widget = self.create_widget(OWSVMClassification,
11+
stored_settings={"auto_apply": False})
12+
self.init()
13+
gamma_spin = self.widget._kernel_params[0]
14+
values = [self.widget._default_gamma, gamma_spin.maximum()]
4315

44-
def test_kernel_display(self):
45-
""" Check if right spinner box for selected kernel are visible after we select kernel """
46-
for button_pos, value in ((0, [False, False, False]),
47-
(1, [True, True, True]),
48-
(2, [True, False, False]),
49-
(3, [True, True, False])):
50-
self.widget.kernel_box.buttons[button_pos].click()
51-
self.widget.kernel_box.buttons[button_pos].isChecked()
52-
self.assertEqual([not self.spin_boxes[i].box.isHidden() for i in range(2, 5)], value)
16+
def getter():
17+
value = gamma_spin.value()
18+
return gamma_spin.specialValueText() \
19+
if value == gamma_spin.minimum() else value
5320

54-
def test_optimization_box_visible(self):
55-
""" Check if both spinner box is visible after starting widget """
56-
self.assertEqual(self.spin_boxes[5].box.isHidden(), False)
57-
self.assertEqual(self.spin_boxes[6].box.isHidden(), False)
21+
def setter(value):
22+
if value == gamma_spin.specialValueText():
23+
gamma_spin.setValue(gamma_spin.minimum())
24+
else:
25+
gamma_spin.setValue(value)
5826

59-
def test_optimization_box_checked(self):
60-
""" Check if spinner box for iteration limit is enabled or disabled """
61-
for value in (True, False):
62-
self.widget.max_iter_check_box.setChecked(value)
63-
self.assertEqual(self.widget.max_iter_check_box.isChecked(), value)
64-
self.assertEqual(self.spin_boxes[6].isEnabled(), value)
27+
self.parameters = [
28+
ParameterMapping("C", self.widget.c_spin),
29+
ParameterMapping("gamma", self.widget._kernel_params[0],
30+
values=values, setter=setter, getter=getter),
31+
ParameterMapping("coef0", self.widget._kernel_params[1]),
32+
ParameterMapping("degree", self.widget._kernel_params[2]),
33+
ParameterMapping("tol", self.widget.tol_spin),
34+
ParameterMapping("max_iter", self.widget.max_iter_spin[1])]
6535

66-
def test_type_button_checked(self):
67-
""" Check if SVM type is selected after click """
68-
self.widget.type_box.buttons[0].click()
69-
self.assertEqual(self.widget.type_box.buttons[0].isChecked(), True)
70-
self.widget.type_box.buttons[1].click()
71-
self.assertEqual(self.widget.type_box.buttons[1].isChecked(), True)
36+
def test_parameters_unchecked(self):
37+
"""Check learner and model for various values of all parameters
38+
when Iteration limit is not checked
39+
"""
40+
self.widget.max_iter_spin[0].setCheckState(False)
41+
self.parameters[-1] = DefaultParameterMapping("max_iter", -1)
42+
self.test_parameters()
7243

73-
def test_type_button_properties_visible(self):
74-
""" Check if spinner box in SVM type are visible """
75-
self.assertEqual(not self.spin_boxes[0].isHidden(), True)
76-
self.assertEqual(not self.spin_boxes[1].isHidden(), True)
44+
def test_parameters_svm_type(self):
45+
"""Check learner and model for various values of all parameters
46+
when NuSVM is chosen
47+
"""
48+
self.assertEqual(self.widget.svmtype, OWSVMClassification.C_SVC)
49+
# setChecked(True) does not trigger callback event
50+
self.widget.nu_radio.click()
51+
self.assertEqual(self.widget.svmtype, OWSVMClassification.Nu_SVC)
52+
self.parameters[0] = ParameterMapping("nu", self.widget.nu_spin)
53+
self.test_parameters()
7754

78-
def test_data_before_apply(self):
79-
""" Check if data are set """
80-
self.widget.set_data(Table("iris")[:100])
81-
self.widget.apply()
82-
self.assertEqual(len(self.widget.data), 100)
83-
84-
def test_output_signal_learner(self):
85-
""" Check if we have on output learner """
86-
self.widget.kernel_box.buttons[0].click()
87-
self.widget.set_data(Table("iris")[:100])
88-
self.widget.apply()
89-
self.assertNotEqual(self.widget.learner, None)
55+
def test_kernel_equation(self):
56+
"""Check if the right equation is written according to kernel """
57+
for i in range(4):
58+
if self.widget.kernel_box.buttons[i].isChecked():
59+
self.assertEqual(self.widget.kernel_eq,
60+
self.widget.kernels[i][1])
61+
break
62+
for i in range(4):
63+
self.widget.kernel_box.buttons[i].click()
64+
self.assertEqual(self.widget.kernel_eq, self.widget.kernels[i][1])
9065

91-
def test_output_params(self):
92-
""" Check ouput params """
93-
self.widget.kernel_box.buttons[0].click()
94-
self.widget.set_data(Table("iris")[:100])
95-
self.widget.max_iter_check_box.setChecked(True)
96-
self.widget.apply()
97-
self.widget.type_box.buttons[0].click()
98-
params = self.widget.learner.params
99-
self.assertEqual(params.get('C'), self.spin_boxes[0].value())
100-
self.widget.type_box.buttons[1].click()
101-
params = self.widget.learner.params
102-
self.assertEqual(params.get('nu'), self.spin_boxes[1].value())
103-
self.assertEqual(params.get('gamma'), self.spin_boxes[2].value())
104-
self.assertEqual(params.get('coef0'), self.spin_boxes[3].value())
105-
self.assertEqual(params.get('degree'), self.spin_boxes[4].value())
106-
self.assertEqual(params.get('tol'), self.spin_boxes[5].value())
107-
self.assertEqual(params.get('max_iter'), self.spin_boxes[6].value())
66+
def test_kernel_spins(self):
67+
"""Check if the right spins are visible according to kernel """
68+
for i, hidden in enumerate([[True, True, True],
69+
[False, False, False],
70+
[False, True, True],
71+
[False, False, True]]):
72+
if self.widget.kernel_box.buttons[i].isChecked():
73+
self.assertEqual([self.widget._kernel_params[j].box.isHidden()
74+
for j in range(3)], hidden)
75+
break
76+
for i, hidden in enumerate([[True, True, True],
77+
[False, False, False],
78+
[False, True, True],
79+
[False, False, True]]):
80+
self.widget.kernel_box.buttons[i].click()
81+
self.assertEqual([self.widget._kernel_params[j].box.isHidden()
82+
for j in range(3)], hidden)

Orange/widgets/regression/owsvmregression.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -40,38 +40,35 @@ def _add_type_box(self):
4040
self.controlArea, self, "svrtype", [], box="SVR Type",
4141
orientation=form)
4242

43-
form.addWidget(gui.appendRadioButton(box, "ε-SVR", addToLayout=False),
44-
0, 0, Qt.AlignLeft)
45-
form.addWidget(QtGui.QLabel("Cost (C):"),
46-
0, 1, Qt.AlignRight)
47-
form.addWidget(gui.doubleSpin(box, self, "epsilon_C", 0.1, 512.0, 0.1,
48-
decimals=2, addToLayout=False),
49-
0, 2)
50-
form.addWidget(QLabel("Loss epsilon (ε):"),
51-
1, 1, Qt.AlignRight)
52-
form.addWidget(gui.doubleSpin(box, self, "epsilon", 0.1, 512.0, 0.1,
53-
decimals=2, addToLayout=False),
54-
1, 2)
43+
self.epsilon_radio = gui.appendRadioButton(box, "ε-SVR",
44+
addToLayout=False)
45+
self.epsilon_C_spin = gui.doubleSpin(box, self, "epsilon_C", 0.1, 512.0,
46+
0.1, decimals=2, addToLayout=False)
47+
self.epsilon_spin = gui.doubleSpin(box, self, "epsilon", 0.1, 512.0,
48+
0.1, decimals=2, addToLayout=False)
49+
form.addWidget(self.epsilon_radio, 0, 0, Qt.AlignLeft)
50+
form.addWidget(QtGui.QLabel("Cost (C):"), 0, 1, Qt.AlignRight)
51+
form.addWidget(self.epsilon_C_spin, 0, 2)
52+
form.addWidget(QLabel("Loss epsilon (ε):"), 1, 1, Qt.AlignRight)
53+
form.addWidget(self.epsilon_spin, 1, 2)
5554

56-
form.addWidget(gui.appendRadioButton(box, "ν-SVR", addToLayout=False),
57-
2, 0, Qt.AlignLeft)
58-
form.addWidget(QLabel("Cost (C):"),
59-
2, 1, Qt.AlignRight)
60-
form.addWidget(gui.doubleSpin(box, self, "nu_C", 0.1, 512.0, 0.1,
61-
decimals=2, addToLayout=False),
62-
2, 2)
63-
form.addWidget(QLabel("Complexity bound (ν):"),
64-
3, 1, Qt.AlignRight)
65-
form.addWidget(gui.doubleSpin(box, self, "nu", 0.05, 1.0, 0.05,
66-
decimals=2, addToLayout=False),
67-
3, 2)
55+
self.nu_radio = gui.appendRadioButton(box, "ν-SVR", addToLayout=False)
56+
self.nu_C_spin = gui.doubleSpin(box, self, "nu_C", 0.1, 512.0, 0.1,
57+
decimals=2, addToLayout=False)
58+
self.nu_spin = gui.doubleSpin(box, self, "nu", 0.05, 1.0, 0.05,
59+
decimals=2, addToLayout=False)
60+
form.addWidget(self.nu_radio, 2, 0, Qt.AlignLeft)
61+
form.addWidget(QLabel("Cost (C):"), 2, 1, Qt.AlignRight)
62+
form.addWidget(self.nu_C_spin, 2, 2)
63+
form.addWidget(QLabel("Complexity bound (ν):"), 3, 1, Qt.AlignRight)
64+
form.addWidget(self.nu_spin, 3, 2)
6865

6966
def create_learner(self):
7067
kernel = ["linear", "poly", "rbf", "sigmoid"][self.kernel_type]
7168
common_args = dict(
7269
kernel=kernel,
7370
degree=self.degree,
74-
gamma=self.gamma,
71+
gamma=self.gamma if self.gamma else self._default_gamma,
7572
coef0=self.coef0,
7673
tol=self.tol,
7774
preprocessors=self.preprocessors

0 commit comments

Comments
 (0)