Skip to content

Commit a8a6c7f

Browse files
committed
Correlations: Add progress bar
1 parent 2fc7ca5 commit a8a6c7f

File tree

6 files changed

+59
-9
lines changed

6 files changed

+59
-9
lines changed

Orange/widgets/data/owcorrelations.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,32 @@ def stopped(self):
196196
header = self.rank_table.horizontalHeader()
197197
header.setSectionResizeMode(1, QHeaderView.Stretch)
198198

199+
def start(self, task, *args, **kwargs):
200+
self.__set_state_ready()
201+
super().start(task, *args, **kwargs)
202+
self.master.progressBarInit()
203+
self.master.setBlocking(True)
204+
205+
def cancel(self):
206+
super().cancel()
207+
self.__set_state_ready()
208+
209+
def _connect_signals(self, state):
210+
super()._connect_signals(state)
211+
state.progress_changed.connect(self.master.progressBarSet)
212+
213+
def _disconnect_signals(self, state):
214+
super()._disconnect_signals(state)
215+
state.progress_changed.disconnect(self.master.progressBarSet)
216+
217+
def _on_task_done(self, future):
218+
super()._on_task_done(future)
219+
self.__set_state_ready()
220+
221+
def __set_state_ready(self):
222+
self.master.progressBarFinished()
223+
self.master.setBlocking(False)
224+
199225

200226
class OWCorrelations(OWWidget):
201227
name = "Correlations"

Orange/widgets/data/tests/test_owcorrelations.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ def test_feature_combo(self):
229229
if attr.is_continuous]
230230
self.assertEqual(len(feature_combo.model()), len(cont_attributes) + 1)
231231

232+
self.wait_until_stop_blocking()
232233
self.send_signal(self.widget.Inputs.data, Table("housing"))
233234
self.assertEqual(len(feature_combo.model()), 14)
234235

@@ -281,6 +282,7 @@ def test_send_report(self):
281282
"""Test report """
282283
self.send_signal(self.widget.Inputs.data, self.data_cont)
283284
self.widget.report_button.click()
285+
self.wait_until_stop_blocking()
284286
self.send_signal(self.widget.Inputs.data, None)
285287
self.widget.report_button.click()
286288

Orange/widgets/visualize/owlinearprojection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(self, master):
5959
gui.rubber(box)
6060
self.last_run_n_attrs = None
6161
self.attr_color = master.attr_color
62+
self.attrs = []
6263

6364
def initialize(self):
6465
super().initialize()
@@ -95,6 +96,8 @@ def check_preconditions(self):
9596

9697
def state_count(self):
9798
n_all_attrs = len(self.attrs)
99+
if not n_all_attrs:
100+
return 0
98101
n_attrs = self.n_attrs
99102
return factorial(n_all_attrs) // (2 * factorial(n_all_attrs - n_attrs) * n_attrs)
100103

Orange/widgets/visualize/tests/test_owlinearprojection.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,13 +203,15 @@ def setUp(self):
203203
def test_discrete_class(self):
204204
self.send_signal(self.widget.Inputs.data, self.data)
205205
run_vizrank(self.vizrank.compute_score,
206-
self.vizrank.iterate_states(None), [], Mock())
206+
self.vizrank.iterate_states(None),
207+
[], 0, self.vizrank.state_count(), Mock())
207208

208209
def test_continuous_class(self):
209210
data = Table("housing")[::100]
210211
self.send_signal(self.widget.Inputs.data, data)
211212
run_vizrank(self.vizrank.compute_score,
212-
self.vizrank.iterate_states(None), [], Mock())
213+
self.vizrank.iterate_states(None),
214+
[], 0, self.vizrank.state_count(), Mock())
213215

214216
def test_set_attrs(self):
215217
self.send_signal(self.widget.Inputs.data, self.data)

Orange/widgets/visualize/tests/test_vizrankdialog.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_run_vizrank(self):
3131
# run through all states
3232
task.is_interruption_requested.return_value = False
3333
states = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
34-
res = run_vizrank(compute_score, chain(states), scores, task)
34+
res = run_vizrank(compute_score, chain(states), scores, 0, 6, task)
3535

3636
next_state = self.assertQueueEqual(
3737
res.queue, [0, 0, 0, 3, 2, 5], compute_score,
@@ -41,14 +41,15 @@ def test_run_vizrank(self):
4141
self.assertListEqual(res.scores, res_scores)
4242
self.assertIsNot(scores, res.scores)
4343
self.assertEqual(task.set_partial_result.call_count, 6)
44+
self.assertEqual(task.set_progress_value.call_count, 6)
4445

4546
def test_run_vizrank_interrupt(self):
4647
scores, task = [], Mock()
4748
# interrupt calculation in third iteration
4849
task.is_interruption_requested.side_effect = lambda: \
4950
True if task.is_interruption_requested.call_count > 2 else False
5051
states = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
51-
res = run_vizrank(compute_score, chain(states), scores, task)
52+
res = run_vizrank(compute_score, chain(states), scores, 0, 6, task)
5253

5354
next_state = self.assertQueueEqual(
5455
res.queue, [0, 0], compute_score, states[:2], states[1:3])
@@ -57,11 +58,14 @@ def test_run_vizrank_interrupt(self):
5758
self.assertListEqual(res.scores, res_scores)
5859
self.assertIsNot(scores, res.scores)
5960
self.assertEqual(task.set_partial_result.call_count, 2)
61+
self.assertEqual(task.set_progress_value.call_count, 2)
62+
task.set_progress_value.assert_called_with(int(1 / 6 * 100))
6063

6164
# continue calculation through all states
6265
task.is_interruption_requested.side_effect = lambda: False
6366
i = states.index(next_state)
64-
res = run_vizrank(compute_score, chain(states[i:]), res_scores, task)
67+
res = run_vizrank(compute_score, chain(states[i:]),
68+
res_scores, 2, 6, task)
6569

6670
next_state = self.assertQueueEqual(
6771
res.queue, [0, 3, 2, 5], compute_score, states[2:],
@@ -71,6 +75,8 @@ def test_run_vizrank_interrupt(self):
7175
self.assertListEqual(res.scores, res_scores)
7276
self.assertIsNot(scores, res.scores)
7377
self.assertEqual(task.set_partial_result.call_count, 6)
78+
self.assertEqual(task.set_progress_value.call_count, 6)
79+
task.set_progress_value.assert_called_with(int(5 / 6 * 100))
7480

7581
def assertQueueEqual(self, queue, positions, f, states, next_states):
7682
self.assertIsInstance(queue, Queue)
@@ -96,7 +102,10 @@ def invoke_on_partial_result():
96102
widget.on_partial_result(run_vizrank(
97103
widget.compute_score,
98104
widget.iterate_states(widget.saved_state),
99-
widget.scores, task
105+
widget.scores,
106+
widget.saved_progress,
107+
widget.state_count(),
108+
task
100109
))
101110

102111
task = Mock()
@@ -107,6 +116,7 @@ def invoke_on_partial_result():
107116
widget.compute_score = compute_score
108117
widget.iterate_states = iterate_states
109118
widget.row_for_state = lambda sc, _: [QStandardItem(str(sc))]
119+
widget.state_count = lambda: len(states)
110120

111121
# interrupt calculation in third iteration
112122
task.is_interruption_requested.side_effect = lambda: \
@@ -117,6 +127,7 @@ def invoke_on_partial_result():
117127
sorted([compute_score(x) for x in states[:2]])):
118128
self.assertEqual(widget.rank_model.item(row, 0).text(), str(score))
119129
self.assertEqual(widget.saved_progress, 2)
130+
task.set_progress_value.assert_called_with(int(1 / 6 * 100))
120131

121132
# continue calculation through all states
122133
task.is_interruption_requested.side_effect = lambda: False
@@ -126,6 +137,7 @@ def invoke_on_partial_result():
126137
sorted([compute_score(x) for x in states])):
127138
self.assertEqual(widget.rank_model.item(row, 0).text(), str(score))
128139
self.assertEqual(widget.saved_progress, 6)
140+
task.set_progress_value.assert_called_with(int(5 / 6 * 100))
129141

130142

131143
if __name__ == "__main__":

Orange/widgets/visualize/utils/__init__.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,8 @@ def toggle(self):
366366
self.progressBarInit()
367367
self.before_running()
368368
self.start(run_vizrank, self.compute_score,
369-
self.iterate_states(self.saved_state), self.scores)
369+
self.iterate_states(self.saved_state), self.scores,
370+
self.saved_progress, self.state_count())
370371
else:
371372
self.button.setText("Continue")
372373
self.button.repaint()
@@ -382,8 +383,8 @@ def stopped(self):
382383
pass
383384

384385

385-
def run_vizrank(compute_score: Callable, states: Iterator,
386-
scores: List, task: TaskState):
386+
def run_vizrank(compute_score: Callable, states: Iterator, scores: List,
387+
saved_progress: int, state_count: int, task: TaskState):
387388
res = Result(queue=Queue(), scores=None)
388389
scores = scores.copy()
389390

@@ -402,10 +403,14 @@ def do_work(st, next_st):
402403

403404
state = None
404405
next_state = next(states)
406+
count = saved_progress
405407
try:
406408
while True:
407409
if task.is_interruption_requested():
408410
return res
411+
if state_count < 100 or count % (state_count // 100) == 0:
412+
task.set_progress_value(int(count * 100 / max(1, state_count)))
413+
count += 1
409414
state = copy.copy(next_state)
410415
next_state = copy.copy(next(states))
411416
do_work(state, next_state)

0 commit comments

Comments
 (0)