From cd0414fa75b51fae39c16436021002cfee21bb88 Mon Sep 17 00:00:00 2001 From: MarcMush <35898736+MarcMush@users.noreply.github.com> Date: Sun, 4 Aug 2024 17:47:29 +0200 Subject: [PATCH] Return of the threaded detection (#328) --- src/ProgressMeter.jl | 17 ++++++++- test/core.jl | 90 +++++++++++++++++++++++++++++++++++++------- 2 files changed, 92 insertions(+), 15 deletions(-) diff --git a/src/ProgressMeter.jl b/src/ProgressMeter.jl index d9923e6..15a9e4a 100644 --- a/src/ProgressMeter.jl +++ b/src/ProgressMeter.jl @@ -81,7 +81,8 @@ Base.@kwdef mutable struct ProgressCore numprintedvalues::Int = 0 # num values printed below progress in last iteration prev_update_count::Int = 1 # counter at last update printed::Bool = false # true if we have issued at least one status update - safe_lock::Bool = Threads.nthreads() > 1 # set to false for non-threaded tight loops + safe_lock::Int = 2*(Threads.nthreads()>1) # 0: no lock, 1: lock, 2: detect + thread_id::Int = Threads.threadid() # id of the thread that created the progressmeter tinit::Float64 = time() # time meter was initialized tlast::Float64 = time() # time of last update tsecond::Float64 = time() # ignore the first loop given usually uncharacteristically slow @@ -448,8 +449,20 @@ end predicted_updates_per_dt_have_passed(p::AbstractProgress) = p.counter - p.prev_update_count >= p.check_iterations +function is_threading(p::AbstractProgress) + p.safe_lock == 0 && return false + p.safe_lock == 1 && return true + if p.thread_id != Threads.threadid() + lock(p.lock) do + p.safe_lock = 1 + end + return true + end + return false +end + function lock_if_threading(f::Function, p::AbstractProgress) - if p.safe_lock + if is_threading(p) lock(p.lock) do f() end diff --git a/test/core.jl b/test/core.jl index cf41c7d..0785b5d 100644 --- a/test/core.jl +++ b/test/core.jl @@ -25,7 +25,7 @@ for ns in [1, 9, 10, 99, 100, 999, 1_000, 9_999, 10_000, 99_000, 100_000, 999_99 end # Performance test (from #171, #323) -function prog_perf(n; dt=0.1, enabled=true, force=false, safe_lock=false) +function prog_perf(n; dt=0.1, enabled=true, force=false, safe_lock=0) prog = Progress(n; dt, enabled, safe_lock) x = 0.0 for i in 1:n @@ -43,38 +43,85 @@ function noprog_perf(n) return x end +function prog_threaded(n; dt=0.1, enabled=true, force=false, safe_lock=2) + prog = Progress(n; dt, enabled, safe_lock) + x = Threads.Atomic{Float64}(0.0) + Threads.@threads for i in 1:n + Threads.atomic_add!(x, rand()) + next!(prog; force) + end + return x +end + +function noprog_threaded(n) + x = Threads.Atomic{Float64}(0.0) + Threads.@threads for i in 1:n + Threads.atomic_add!(x, rand()) + end + return x +end + println("Performance tests...") #precompile noprog_perf(10) prog_perf(10) -prog_perf(10; safe_lock=true) -prog_perf(10; dt=9999) +prog_perf(10; safe_lock=1) +prog_perf(10; dt=9999.9) prog_perf(10; enabled=false) -prog_perf(10; enabled=false, safe_lock=true) +prog_perf(10; enabled=false, safe_lock=1) prog_perf(10; force=true) -t_noprog = (@elapsed noprog_perf(10^8))/10^8 -t_prog = (@elapsed prog_perf(10^8))/10^8 -t_lock = (@elapsed prog_perf(10^8; safe_lock=true))/10^8 -t_noprint = (@elapsed prog_perf(10^8; dt=9999))/10^8 -t_disabled = (@elapsed prog_perf(10^8; enabled=false))/10^8 -t_disabled_lock = (@elapsed prog_perf(10^8; enabled=false, safe_lock=true))/10^8 -t_force = (@elapsed prog_perf(10^2; force=true))/10^2 +noprog_threaded(2*Threads.nthreads()) +prog_threaded(2*Threads.nthreads()) +prog_threaded(2*Threads.nthreads(); safe_lock=1) +prog_threaded(2*Threads.nthreads(); dt=9999) +prog_threaded(2*Threads.nthreads(); enabled=false) +prog_threaded(2*Threads.nthreads(); force=true) + +N = 10^8 +N_force = 1000 +t_noprog = (@elapsed noprog_perf(N))/N +t_prog = (@elapsed prog_perf(N))/N +t_lock = (@elapsed prog_perf(N; safe_lock=1))/N +t_detect = (@elapsed prog_perf(N; safe_lock=2))/N +t_noprint = (@elapsed prog_perf(N; dt=9999.9))/N +t_disabled = (@elapsed prog_perf(N; enabled=false))/N +t_disabled_lock = (@elapsed prog_perf(N; enabled=false, safe_lock=1))/N +t_force = (@elapsed prog_perf(N_force; force=true))/N_force + +Nth = Threads.nthreads() * 10^6 +Nth_force = Threads.nthreads() * 100 +th_noprog = (@elapsed noprog_threaded(Nth))/Nth +th_detect = (@elapsed prog_threaded(Nth))/Nth +th_lock = (@elapsed prog_threaded(Nth; safe_lock=1))/Nth +th_noprint = (@elapsed prog_threaded(Nth; dt=9999.9))/Nth +th_disabled = (@elapsed prog_threaded(Nth; enabled=false))/Nth +th_force = (@elapsed prog_threaded(Nth_force; force=true))/Nth_force println("Performance results:") println("without progress: ", ProgressMeter.speedstring(t_noprog)) -println("with defaults: ", ProgressMeter.speedstring(t_prog)) +println("with no lock: ", ProgressMeter.speedstring(t_prog)) println("with no printing: ", ProgressMeter.speedstring(t_noprint)) println("with disabled: ", ProgressMeter.speedstring(t_disabled)) println("with lock: ", ProgressMeter.speedstring(t_lock)) +println("with automatic lock: ", ProgressMeter.speedstring(t_detect)) println("with lock, disabled: ", ProgressMeter.speedstring(t_disabled_lock)) println("with force: ", ProgressMeter.speedstring(t_force)) +println() +println("Threaded performance results: ($(Threads.nthreads()) threads)") +println("without progress: ", ProgressMeter.speedstring(th_noprog)) +println("with automatic lock: ", ProgressMeter.speedstring(th_detect)) +println("with forced lock: ", ProgressMeter.speedstring(th_lock)) +println("with no printing: ", ProgressMeter.speedstring(th_noprint)) +println("with disabled: ", ProgressMeter.speedstring(th_disabled)) +println("with force: ", ProgressMeter.speedstring(th_force)) if get(ENV, "CI", "false") == "false" # CI environment is too unreliable for performance tests @test t_prog < 9*t_noprog end + # Avoid a NaN due to the estimated print time compensation # https://github.com/timholy/ProgressMeter.jl/issues/209 prog = Progress(10) @@ -116,7 +163,24 @@ function simple_sum(n; safe_lock = true) return s end p = Progress(10) -@test p.safe_lock == (Threads.nthreads() > 1) +@test (p.safe_lock) == 2*(Threads.nthreads() > 1) p = Progress(10; safe_lock = false) @test p.safe_lock == false @test simple_sum(10; safe_lock = true) ≈ simple_sum(10; safe_lock = false) + + +# Brute-force thread safety + +function test_thread(N) + p = Progress(N) + Threads.@threads for _ in 1:N + next!(p) + end +end + +println("Brute-forcing thread safety... ($(Threads.nthreads()) threads)") +@time for i in 1:10^5 + test_thread(Threads.nthreads()) +end + +