Skip to content

Commit 807496a

Browse files
authored
fix @showprogress @distributed (timholy#295)
* fix macro * add test in global scope * don't forget to close while take!(ch)
1 parent 65d049e commit 807496a

File tree

2 files changed

+48
-22
lines changed

2 files changed

+48
-22
lines changed

src/ProgressMeter.jl

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -796,47 +796,38 @@ function showprogressdistributed(args...)
796796
r = loop.args[1].args[2]
797797
body = loop.args[2]
798798

799-
setup = quote
800-
n = length($(esc(r)))
801-
p = Progress(n, $(showprogress_process_args(progressargs)...))
802-
ch = RemoteChannel(() -> Channel{Bool}(n))
803-
end
804-
805799
if na == 1
806800
# would be nice to do this with @sync @distributed but @sync is broken
807801
# https://github.com/JuliaLang/julia/issues/28979
808802
compute = quote
809-
display = @async let i = 0
810-
while i < n
811-
take!(ch)
812-
next!(p)
813-
i += 1
814-
end
815-
end
816-
@distributed for $(esc(var)) = $(esc(r))
803+
waiting = @distributed for $(esc(var)) = $(esc(r))
817804
$(esc(body))
818805
put!(ch, true)
819806
end
807+
wait(waiting)
820808
nothing
821809
end
822810
else
823811
compute = quote
824-
display = @async while take!(ch) next!(p) end
825-
results = @distributed $(esc(reducer)) for $(esc(var)) = $(esc(r))
812+
@distributed $(esc(reducer)) for $(esc(var)) = $(esc(r))
826813
x = $(esc(body))
827814
put!(ch, true)
828815
x
829816
end
830-
put!(ch, false)
831-
results
832817
end
833818
end
834819

835820
quote
836-
$setup
837-
results = $compute
838-
wait(display)
839-
results
821+
let n = length($(esc(r)))
822+
p = Progress(n, $(showprogress_process_args(progressargs)...))
823+
ch = RemoteChannel(() -> Channel{Bool}(n))
824+
825+
@async while take!(ch) next!(p) end
826+
results = $compute
827+
put!(ch, false)
828+
finish!(p)
829+
results
830+
end
840831
end
841832
end
842833

test/test.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,41 @@ end
372372
println("Testing @showprogress macro on distributed for loop without reducer")
373373
testfunc16(3000, 0.01, 0.001)
374374

375+
function testfunc16cb(N, dt, tsleep)
376+
ProgressMeter.@showprogress dt=dt @distributed for i in N
377+
if rand() < 0.7
378+
sleep(tsleep)
379+
end
380+
200 < i < 400 && continue
381+
i > 1500 && break
382+
i ^ 2
383+
end
384+
end
385+
386+
println("Testing @showprogress macro on distributed for loop with continue")
387+
testfunc16cb(1:1000, 0.01, 0.002)
388+
389+
println("Testing @showprogress macro on distributed for loop with break")
390+
testfunc16cb(1000:2000, 0.01, 0.003)
391+
392+
393+
println("testing `@showprogress @distributed` in global scope")
394+
@showprogress @distributed for i in 1:10
395+
sleep(0.1)
396+
i^2
397+
end
398+
399+
println("testing `@showprogress @distributed (+)` in global scope")
400+
# https://github.com/timholy/ProgressMeter.jl/issues/243
401+
result = @showprogress @distributed (+) for i in 1:10
402+
sleep(0.1)
403+
i^2
404+
end
405+
@test result == sum(abs2, 1:10)
406+
407+
408+
409+
375410
function testfunc17()
376411
n = 30
377412
p = ProgressMeter.Progress(n, start=15)

0 commit comments

Comments
 (0)