From 807496a967944abed72be74f371c43cc876dbbec Mon Sep 17 00:00:00 2001 From: MarcMush <35898736+MarcMush@users.noreply.github.com> Date: Wed, 31 Jan 2024 01:30:22 +0100 Subject: [PATCH] fix `@showprogress @distributed` (#295) * fix macro * add test in global scope * don't forget to close while take!(ch) --- src/ProgressMeter.jl | 35 +++++++++++++---------------------- test/test.jl | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 22 deletions(-) diff --git a/src/ProgressMeter.jl b/src/ProgressMeter.jl index 4b03c62..045f1c1 100644 --- a/src/ProgressMeter.jl +++ b/src/ProgressMeter.jl @@ -796,47 +796,38 @@ function showprogressdistributed(args...) r = loop.args[1].args[2] body = loop.args[2] - setup = quote - n = length($(esc(r))) - p = Progress(n, $(showprogress_process_args(progressargs)...)) - ch = RemoteChannel(() -> Channel{Bool}(n)) - end - if na == 1 # would be nice to do this with @sync @distributed but @sync is broken # https://github.com/JuliaLang/julia/issues/28979 compute = quote - display = @async let i = 0 - while i < n - take!(ch) - next!(p) - i += 1 - end - end - @distributed for $(esc(var)) = $(esc(r)) + waiting = @distributed for $(esc(var)) = $(esc(r)) $(esc(body)) put!(ch, true) end + wait(waiting) nothing end else compute = quote - display = @async while take!(ch) next!(p) end - results = @distributed $(esc(reducer)) for $(esc(var)) = $(esc(r)) + @distributed $(esc(reducer)) for $(esc(var)) = $(esc(r)) x = $(esc(body)) put!(ch, true) x end - put!(ch, false) - results end end quote - $setup - results = $compute - wait(display) - results + let n = length($(esc(r))) + p = Progress(n, $(showprogress_process_args(progressargs)...)) + ch = RemoteChannel(() -> Channel{Bool}(n)) + + @async while take!(ch) next!(p) end + results = $compute + put!(ch, false) + finish!(p) + results + end end end diff --git a/test/test.jl b/test/test.jl index 0a86bd5..dc15b4c 100644 --- a/test/test.jl +++ b/test/test.jl @@ -372,6 +372,41 @@ end println("Testing @showprogress macro on distributed for loop without reducer") testfunc16(3000, 0.01, 0.001) +function testfunc16cb(N, dt, tsleep) + ProgressMeter.@showprogress dt=dt @distributed for i in N + if rand() < 0.7 + sleep(tsleep) + end + 200 < i < 400 && continue + i > 1500 && break + i ^ 2 + end +end + +println("Testing @showprogress macro on distributed for loop with continue") +testfunc16cb(1:1000, 0.01, 0.002) + +println("Testing @showprogress macro on distributed for loop with break") +testfunc16cb(1000:2000, 0.01, 0.003) + + +println("testing `@showprogress @distributed` in global scope") +@showprogress @distributed for i in 1:10 + sleep(0.1) + i^2 +end + +println("testing `@showprogress @distributed (+)` in global scope") +# https://github.com/timholy/ProgressMeter.jl/issues/243 +result = @showprogress @distributed (+) for i in 1:10 + sleep(0.1) + i^2 +end +@test result == sum(abs2, 1:10) + + + + function testfunc17() n = 30 p = ProgressMeter.Progress(n, start=15)