Skip to content

Commit c5d8f21

Browse files
authored
Merge pull request #4011 from VesnaT/correlations_progress_bar
[FIX] Correlations: Add progress bar, retain responsiveness
2 parents 2060147 + 7a52bc6 commit c5d8f21

File tree

6 files changed

+93
-16
lines changed

6 files changed

+93
-16
lines changed

Orange/widgets/data/owcorrelations.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,38 @@ def stopped(self):
196196
header = self.rank_table.horizontalHeader()
197197
header.setSectionResizeMode(1, QHeaderView.Stretch)
198198

199+
def start(self, task, *args, **kwargs):
200+
self.__set_state_ready()
201+
super().start(task, *args, **kwargs)
202+
self.__set_state_busy()
203+
204+
def cancel(self):
205+
super().cancel()
206+
self.__set_state_ready()
207+
208+
def _connect_signals(self, state):
209+
super()._connect_signals(state)
210+
state.progress_changed.connect(self.master.progressBarSet)
211+
state.status_changed.connect(self.master.setStatusMessage)
212+
213+
def _disconnect_signals(self, state):
214+
super()._disconnect_signals(state)
215+
state.progress_changed.disconnect(self.master.progressBarSet)
216+
state.status_changed.disconnect(self.master.setStatusMessage)
217+
218+
def _on_task_done(self, future):
219+
super()._on_task_done(future)
220+
self.__set_state_ready()
221+
222+
def __set_state_ready(self):
223+
self.master.progressBarFinished()
224+
self.master.setBlocking(False)
225+
self.master.setStatusMessage("")
226+
227+
def __set_state_busy(self):
228+
self.master.progressBarInit()
229+
self.master.setBlocking(True)
230+
199231

200232
class OWCorrelations(OWWidget):
201233
name = "Correlations"
@@ -250,7 +282,6 @@ def __init__(self):
250282

251283
self.vizrank, _ = CorrelationRank.add_vizrank(
252284
None, self, None, self._vizrank_selection_changed)
253-
self.vizrank.progressBar = self.progressBar
254285
self.vizrank.button.setEnabled(False)
255286
self.vizrank.threadStopped.connect(self._vizrank_stopped)
256287

Orange/widgets/data/tests/test_owcorrelations.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ def test_feature_combo(self):
229229
if attr.is_continuous]
230230
self.assertEqual(len(feature_combo.model()), len(cont_attributes) + 1)
231231

232+
self.wait_until_stop_blocking()
232233
self.send_signal(self.widget.Inputs.data, Table("housing"))
233234
self.assertEqual(len(feature_combo.model()), 14)
234235

@@ -281,6 +282,7 @@ def test_send_report(self):
281282
"""Test report """
282283
self.send_signal(self.widget.Inputs.data, self.data_cont)
283284
self.widget.report_button.click()
285+
self.wait_until_stop_blocking()
284286
self.send_signal(self.widget.Inputs.data, None)
285287
self.widget.report_button.click()
286288

Orange/widgets/visualize/owlinearprojection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(self, master):
5959
gui.rubber(box)
6060
self.last_run_n_attrs = None
6161
self.attr_color = master.attr_color
62+
self.attrs = []
6263

6364
def initialize(self):
6465
super().initialize()
@@ -95,6 +96,8 @@ def check_preconditions(self):
9596

9697
def state_count(self):
9798
n_all_attrs = len(self.attrs)
99+
if not n_all_attrs:
100+
return 0
98101
n_attrs = self.n_attrs
99102
return factorial(n_all_attrs) // (2 * factorial(n_all_attrs - n_attrs) * n_attrs)
100103

Orange/widgets/visualize/tests/test_owlinearprojection.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,13 +203,15 @@ def setUp(self):
203203
def test_discrete_class(self):
204204
self.send_signal(self.widget.Inputs.data, self.data)
205205
run_vizrank(self.vizrank.compute_score,
206-
self.vizrank.iterate_states(None), [], Mock())
206+
self.vizrank.iterate_states, None,
207+
[], 0, self.vizrank.state_count(), Mock())
207208

208209
def test_continuous_class(self):
209210
data = Table("housing")[::100]
210211
self.send_signal(self.widget.Inputs.data, data)
211212
run_vizrank(self.vizrank.compute_score,
212-
self.vizrank.iterate_states(None), [], Mock())
213+
self.vizrank.iterate_states, None,
214+
[], 0, self.vizrank.state_count(), Mock())
213215

214216
def test_set_attrs(self):
215217
self.send_signal(self.widget.Inputs.data, self.data)

Orange/widgets/visualize/tests/test_vizrankdialog.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ def test_run_vizrank(self):
3131
# run through all states
3232
task.is_interruption_requested.return_value = False
3333
states = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
34-
res = run_vizrank(compute_score, chain(states), scores, task)
34+
res = run_vizrank(compute_score, lambda initial: chain(states),
35+
None, scores, 0, 6, task)
3536

3637
next_state = self.assertQueueEqual(
3738
res.queue, [0, 0, 0, 3, 2, 5], compute_score,
@@ -40,28 +41,33 @@ def test_run_vizrank(self):
4041
res_scores = sorted([compute_score(x) for x in states])
4142
self.assertListEqual(res.scores, res_scores)
4243
self.assertIsNot(scores, res.scores)
43-
self.assertEqual(task.set_partial_result.call_count, 6)
44+
self.assertEqual(task.set_partial_result.call_count, 2)
45+
self.assertEqual(task.set_progress_value.call_count, 7)
4446

4547
def test_run_vizrank_interrupt(self):
4648
scores, task = [], Mock()
4749
# interrupt calculation in third iteration
4850
task.is_interruption_requested.side_effect = lambda: \
4951
True if task.is_interruption_requested.call_count > 2 else False
5052
states = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
51-
res = run_vizrank(compute_score, chain(states), scores, task)
53+
res = run_vizrank(compute_score, lambda initial: chain(states),
54+
None, scores, 0, 6, task)
5255

5356
next_state = self.assertQueueEqual(
5457
res.queue, [0, 0], compute_score, states[:2], states[1:3])
5558
self.assertEqual(next_state, (0, 3))
5659
res_scores = sorted([compute_score(x) for x in states[:2]])
5760
self.assertListEqual(res.scores, res_scores)
5861
self.assertIsNot(scores, res.scores)
59-
self.assertEqual(task.set_partial_result.call_count, 2)
62+
self.assertEqual(task.set_partial_result.call_count, 1)
63+
self.assertEqual(task.set_progress_value.call_count, 3)
64+
task.set_progress_value.assert_called_with(int(1 / 6 * 100))
6065

6166
# continue calculation through all states
6267
task.is_interruption_requested.side_effect = lambda: False
6368
i = states.index(next_state)
64-
res = run_vizrank(compute_score, chain(states[i:]), res_scores, task)
69+
res = run_vizrank(compute_score, lambda initial: chain(states[i:]),
70+
None, res_scores, 2, 6, task)
6571

6672
next_state = self.assertQueueEqual(
6773
res.queue, [0, 3, 2, 5], compute_score, states[2:],
@@ -70,7 +76,9 @@ def test_run_vizrank_interrupt(self):
7076
res_scores = sorted([compute_score(x) for x in states])
7177
self.assertListEqual(res.scores, res_scores)
7278
self.assertIsNot(scores, res.scores)
73-
self.assertEqual(task.set_partial_result.call_count, 6)
79+
self.assertEqual(task.set_partial_result.call_count, 3)
80+
self.assertEqual(task.set_progress_value.call_count, 8)
81+
task.set_progress_value.assert_called_with(int(5 / 6 * 100))
7482

7583
def assertQueueEqual(self, queue, positions, f, states, next_states):
7684
self.assertIsInstance(queue, Queue)
@@ -95,8 +103,12 @@ def iterate_states(initial_state):
95103
def invoke_on_partial_result():
96104
widget.on_partial_result(run_vizrank(
97105
widget.compute_score,
98-
widget.iterate_states(widget.saved_state),
99-
widget.scores, task
106+
widget.iterate_states,
107+
widget.saved_state,
108+
widget.scores,
109+
widget.saved_progress,
110+
widget.state_count(),
111+
task
100112
))
101113

102114
task = Mock()
@@ -107,6 +119,7 @@ def invoke_on_partial_result():
107119
widget.compute_score = compute_score
108120
widget.iterate_states = iterate_states
109121
widget.row_for_state = lambda sc, _: [QStandardItem(str(sc))]
122+
widget.state_count = lambda: len(states)
110123

111124
# interrupt calculation in third iteration
112125
task.is_interruption_requested.side_effect = lambda: \
@@ -117,6 +130,7 @@ def invoke_on_partial_result():
117130
sorted([compute_score(x) for x in states[:2]])):
118131
self.assertEqual(widget.rank_model.item(row, 0).text(), str(score))
119132
self.assertEqual(widget.saved_progress, 2)
133+
task.set_progress_value.assert_called_with(int(1 / 6 * 100))
120134

121135
# continue calculation through all states
122136
task.is_interruption_requested.side_effect = lambda: False
@@ -126,6 +140,7 @@ def invoke_on_partial_result():
126140
sorted([compute_score(x) for x in states])):
127141
self.assertEqual(widget.rank_model.item(row, 0).text(), str(score))
128142
self.assertEqual(widget.saved_progress, 6)
143+
task.set_progress_value.assert_called_with(int(5 / 6 * 100))
129144

130145

131146
if __name__ == "__main__":

Orange/widgets/visualize/utils/__init__.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from operator import attrgetter
77
from queue import Queue, Empty
88
from types import SimpleNamespace as namespace
9-
from typing import Optional, Iterable, List, Callable, Iterator
9+
from typing import Optional, Iterable, List, Callable
10+
from threading import Timer
1011

1112
from AnyQt.QtCore import Qt, QSize, pyqtSignal as Signal, QSortFilterProxyModel
1213
from AnyQt.QtGui import QStandardItemModel, QStandardItem, QColor, QBrush, QPen
@@ -366,7 +367,8 @@ def toggle(self):
366367
self.progressBarInit()
367368
self.before_running()
368369
self.start(run_vizrank, self.compute_score,
369-
self.iterate_states(self.saved_state), self.scores)
370+
self.iterate_states, self.saved_state, self.scores,
371+
self.saved_progress, self.state_count())
370372
else:
371373
self.button.setText("Continue")
372374
self.button.repaint()
@@ -382,10 +384,17 @@ def stopped(self):
382384
pass
383385

384386

385-
def run_vizrank(compute_score: Callable, states: Iterator,
386-
scores: List, task: TaskState):
387+
def run_vizrank(compute_score: Callable, iterate_states: Callable,
388+
saved_state: Optional[Iterable], scores: List,
389+
progress: int, state_count: int, task: TaskState):
390+
task.set_status("Getting combinations...")
391+
task.set_progress_value(0.1)
392+
states = iterate_states(saved_state)
393+
394+
task.set_status("Getting scores...")
387395
res = Result(queue=Queue(), scores=None)
388396
scores = scores.copy()
397+
can_set_partial_result = True
389398

390399
def do_work(st, next_st):
391400
try:
@@ -398,19 +407,34 @@ def do_work(st, next_st):
398407
except Exception: # ignore current state in case of any problem
399408
pass
400409
res.scores = scores.copy()
401-
task.set_partial_result(res)
410+
411+
def reset_flag():
412+
nonlocal can_set_partial_result
413+
can_set_partial_result = True
402414

403415
state = None
404416
next_state = next(states)
405417
try:
406418
while True:
407419
if task.is_interruption_requested():
408420
return res
421+
task.set_progress_value(int(progress * 100 / max(1, state_count)))
422+
progress += 1
409423
state = copy.copy(next_state)
410424
next_state = copy.copy(next(states))
411425
do_work(state, next_state)
426+
# for simple scores (e.g. correlations widget) and many feature
427+
# combinations, the 'partial_result_ready' signal (emitted by
428+
# invoking 'task.set_partial_result') was emitted too frequently
429+
# for a longer period of time and therefore causing the widget
430+
# being unresponsive
431+
if can_set_partial_result:
432+
task.set_partial_result(res)
433+
can_set_partial_result = False
434+
Timer(0.01, reset_flag).start()
412435
except StopIteration:
413436
do_work(state, None)
437+
task.set_partial_result(res)
414438
return res
415439

416440

0 commit comments

Comments
 (0)