Skip to content

Commit

Permalink
lock by default - add safe_lock option for control (#322)
Browse files Browse the repository at this point in the history
* always lock; disable_lock option

* remove trailing space

* add test for new option

* rename disable_lock to safe_lock

* safe_lock is nthreads() > 1 by default

* declare p local to avoid warning in test

* fix test for initialization of safe_lock

* fix test for initialization of safe_lock (2)

* fix test
  • Loading branch information
lmiq authored Jul 11, 2024
1 parent e4b0b45 commit 66ad2be
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 16 deletions.
14 changes: 2 additions & 12 deletions src/ProgressMeter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ Base.@kwdef mutable struct ProgressCore
numprintedvalues::Int = 0 # num values printed below progress in last iteration
prev_update_count::Int = 1 # counter at last update
printed::Bool = false # true if we have issued at least one status update
threads_used::Vector{Int} = Int[] # threads that have used this progress meter
safe_lock::Bool = Threads.nthreads() > 1 # set to false for non-threaded tight loops
tinit::Float64 = time() # time meter was initialized
tlast::Float64 = time() # time of last update
tsecond::Float64 = time() # ignore the first loop given usually uncharacteristically slow
Expand Down Expand Up @@ -441,17 +441,8 @@ end

predicted_updates_per_dt_have_passed(p::AbstractProgress) = p.counter - p.prev_update_count >= p.check_iterations

function is_threading(p::AbstractProgress)
Threads.nthreads() == 1 && return false
length(p.threads_used) > 1 && return true
if !in(Threads.threadid(), p.threads_used)
push!(p.threads_used, Threads.threadid())
end
return length(p.threads_used) > 1
end

function lock_if_threading(f::Function, p::AbstractProgress)
if is_threading(p)
if p.safe_lock
lock(p.lock) do
f()
end
Expand Down Expand Up @@ -817,7 +808,6 @@ function showprogressthreads(args...)
length($(esc(iters)));
$(showprogress_process_args(progressargs)...),
)
append!($(esc(p)).threads_used, 1:Threads.nthreads())
$(esc(expr))
finish!($(esc(p)))
end
Expand Down
16 changes: 16 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,19 @@ prog.n = UInt128(20) # in Progress
@test prog.n == 20
prog.offset = Int8(5) # in ProgressCore
@test prog.offset == 5

# Test safe_lock option, initialization and execution.
function simple_sum(n; safe_lock = true)
p = Progress(n; safe_lock)
s = 0.0
for i in 1:n
s += sin(i)^2
next!(p)
end
return s
end
p = Progress(10)
@test p.safe_lock == (Threads.nthreads() > 1)
p = Progress(10; safe_lock = false)
@test p.safe_lock == false
@test simple_sum(10; safe_lock = true) simple_sum(10; safe_lock = false)
1 change: 1 addition & 0 deletions test/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ finish!(prog)

println("Testing fractional bars")
for front in (['','','','','','', ''], ['' ,'' ,'' ,'' ,'' ,'', ''], ['', '', '',])
local p
p = Progress(100, dt=0.01, barglyphs=BarGlyphs('|','',front,' ','|'), barlen=10)
for i in 1:100
next!(p)
Expand Down
4 changes: 0 additions & 4 deletions test/test_threads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
threadsUsed = fill(false, threads)
vals = ones(n*threads)
p = Progress(n*threads)
p.threads_used = 1:threads # short-circuit the function `is_threading` because it is racy (#232)
Threads.@threads for i = 1:(n*threads)
threadsUsed[Threads.threadid()] = true
vals[i] = 0
Expand All @@ -21,7 +20,6 @@
println("Testing ProgressUnknown() with Threads.@threads across $threads threads")
trigger = 100.0
prog = ProgressUnknown(desc="Attempts at exceeding trigger:")
prog.threads_used = 1:threads
vals = Float64[]
threadsUsed = fill(false, threads)
lk = ReentrantLock()
Expand All @@ -48,7 +46,6 @@
println("Testing ProgressThresh() with Threads.@threads across $threads threads")
thresh = 1.0
prog = ProgressThresh(thresh; desc="Minimizing:")
prog.threads_used = 1:threads
vals = fill(300.0, 1)
threadsUsed = fill(false, threads)
Threads.@threads for _ in 1:100000
Expand Down Expand Up @@ -76,7 +73,6 @@
# threadsUsed = fill(false, threads)
vals = ones(n*threads)
p = Progress(n*threads)
p.threads_used = 1:threads

for t in 1:threads
tasks[t] = Threads.@spawn for i in 1:n
Expand Down

0 comments on commit 66ad2be

Please sign in to comment.