Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support more complex comprehensions #302

Merged
merged 5 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ while true
next!(prog)
rand(1:2*10^8) == 1 && break
end
ProgressMeter.finish!(prog)
finish!(prog)
```

By default, `finish!` changes the spinner to a `✓`, but you can
Expand Down Expand Up @@ -421,7 +421,7 @@ p = Progress(n; enabled = false)
for iter in 1:10
x *= 2
sleep(0.5)
ProgressMeter.next!(p)
next!(p)
end
```

Expand Down
275 changes: 150 additions & 125 deletions src/ProgressMeter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,12 @@
t = time()
barlen = p.barlen isa Nothing ? tty_width(p.desc, p.output, p.showspeed) : p.barlen
percentage_complete = 100.0 * p.counter / p.n
percentage_rounded = 100
bar = barstring(barlen, percentage_complete, barglyphs=p.barglyphs)
elapsed_time = t - p.tinit
dur = durationstring(elapsed_time)
spacer = endswith(p.desc, " ") ? "" : " "
msg = @sprintf "%s%s%3u%%%s Time: %s" p.desc spacer round(Int, percentage_complete) bar dur
msg = @sprintf "%s%s%3u%%%s Time: %s" p.desc spacer percentage_rounded bar dur
if p.showspeed
sec_per_iter = elapsed_time / (p.counter - p.start)
msg = @sprintf "%s (%s)" msg speedstring(sec_per_iter)
Expand All @@ -292,6 +293,7 @@
if t > p.tlast+p.dt
barlen = p.barlen isa Nothing ? tty_width(p.desc, p.output, p.showspeed) : p.barlen
percentage_complete = 100.0 * p.counter / p.n
percentage_rounded = min(99, round(Int, percentage_complete)) # don't round up to 100% if not finished (#300)
bar = barstring(barlen, percentage_complete, barglyphs=p.barglyphs)
elapsed_time = t - p.tinit
est_total_time = elapsed_time * (p.n - p.start) / (p.counter - p.start)
Expand All @@ -302,7 +304,7 @@
eta = "N/A"
end
spacer = endswith(p.desc, " ") ? "" : " "
msg = @sprintf "%s%s%3u%%%s ETA: %s" p.desc spacer round(Int, percentage_complete) bar eta
msg = @sprintf "%s%s%3u%%%s ETA: %s" p.desc spacer percentage_rounded bar eta
if p.showspeed
sec_per_iter = elapsed_time / (p.counter - p.start)
msg = @sprintf "%s (%s)" msg speedstring(sec_per_iter)
Expand Down Expand Up @@ -780,10 +782,6 @@
progressargs = args[1:end-1]
expr = Base.remove_linenums!(args[end])

if expr.head != :macrocall || expr.args[1] != Symbol("@distributed")
throw(ArgumentError("malformed @showprogress @distributed expression"))
end

distargs = filter(x -> !(x isa LineNumberNode), expr.args[2:end])
na = length(distargs)
if na == 1
Expand Down Expand Up @@ -844,7 +842,7 @@
iters = loop.args[1].args[end]

p = gensym()
push!(loop.args[end].args, :(ProgressMeter.next!($p)))
push!(loop.args[end].args, :(next!($p)))

quote
$(esc(p)) = Progress(
Expand Down Expand Up @@ -888,146 +886,173 @@
end
progressargs = args[1:end-1]
expr = args[end]
if expr.head == :macrocall && expr.args[1] == Symbol("@distributed")
return showprogressdistributed(args...)
end
if expr.head == :macrocall && expr.args[1] == :(Threads.var"@threads")
return showprogressthreads(args...)

if !isa(expr, Expr)
throw(ArgumentError("Final argument to @showprogress must be a for loop, comprehension, or a map-like function; got $expr"))

Check warning on line 891 in src/ProgressMeter.jl

View check run for this annotation

Codecov / codecov/patch

src/ProgressMeter.jl#L891

Added line #L891 was not covered by tests
end
orig = expr = copy(expr)
if expr.args[1] == :|> # e.g. map(x->x^2) |> sum

if expr.head == :call && expr.args[1] == :|>
# e.g. map(x->x^2) |> sum
expr.args[2] = showprogress(progressargs..., expr.args[2])
return expr

elseif expr.head in (:for, :comprehension, :typed_comprehension)
return showprogress_loop(expr, progressargs)

elseif expr.head == :call
return showprogress_map(expr, progressargs)

elseif expr.head == :do && expr.args[1].head == :call
return showprogress_map(expr, progressargs)

elseif expr.head == :macrocall
macroname = expr.args[1]

if macroname in (Symbol("@distributed"), :(Distributed.@distributed).args[1])
# can be changed to `:(Distributed.var"@distributed")` if support for pre-1.3 is dropped
return showprogressdistributed(args...)

elseif macroname in (Symbol("@threads"), :(Threads.@threads).args[1])
return showprogressthreads(args...)
end
end

throw(ArgumentError("Final argument to @showprogress must be a for loop, comprehension, or a map-like function; got $expr"))

Check warning on line 920 in src/ProgressMeter.jl

View check run for this annotation

Codecov / codecov/patch

src/ProgressMeter.jl#L920

Added line #L920 was not covered by tests
end

function showprogress_map(expr, progressargs)
metersym = gensym("meter")
kind = :invalid # :invalid, :loop, or :map

if isa(expr, Expr)
if expr.head == :for
outerassignidx = 1
loopbodyidx = lastindex(expr.args)
kind = :loop
elseif expr.head == :comprehension
outerassignidx = lastindex(expr.args)
loopbodyidx = 1
kind = :loop
elseif expr.head == :typed_comprehension
outerassignidx = lastindex(expr.args)
loopbodyidx = 2
kind = :loop
elseif expr.head == :call
kind = :map
elseif expr.head == :do
call = expr.args[1]
if call.head == :call
kind = :map
end
end

# isolate call to map
if expr.head == :do
call = expr.args[1]
else
call = expr
end

if kind == :invalid
throw(ArgumentError("Final argument to @showprogress must be a for loop, comprehension, or a map-like function; got $expr"))
elseif kind == :loop
# As of julia 0.5, a comprehension's "loop" is actually one level deeper in the syntax tree.
if expr.head !== :for
@assert length(expr.args) == loopbodyidx
expr = expr.args[outerassignidx] = copy(expr.args[outerassignidx])
@assert expr.head === :generator
outerassignidx = lastindex(expr.args)
loopbodyidx = 1
end
# get args to map to determine progress length
mapargs = collect(Any, filter(call.args[2:end]) do a
return isa(a, Symbol) || isa(a, Number) || !(a.head in (:kw, :parameters))
end)
if expr.head == :do
insert!(mapargs, 1, identity) # to make args for ncalls line up
end

# Transform the first loop assignment
loopassign = expr.args[outerassignidx] = copy(expr.args[outerassignidx])
if loopassign.head === :block # this will happen in a for loop with multiple iteration variables
for i in 2:length(loopassign.args)
loopassign.args[i] = esc(loopassign.args[i])
end
loopassign = loopassign.args[1] = copy(loopassign.args[1])
end
@assert loopassign.head === :(=)
@assert length(loopassign.args) == 2
obj = loopassign.args[2]
loopassign.args[1] = esc(loopassign.args[1])
loopassign.args[2] = :(ProgressWrapper(iterable, $(esc(metersym))))

# Transform the loop body break and return statements
if expr.head === :for
expr.args[loopbodyidx] = showprogress_process_expr(expr.args[loopbodyidx], metersym)
end
# change call to progress_map
mapfun = call.args[1]
call.args[1] = :progress_map

# Escape all args except the loop assignment, which was already appropriately escaped.
for i in 1:length(expr.args)
if i != outerassignidx
expr.args[i] = esc(expr.args[i])
end
end
if orig !== expr
# We have additional escaping to do; this will occur for comprehensions with julia 0.5 or later.
for i in 1:length(orig.args)-1
orig.args[i] = esc(orig.args[i])
end
end
# escape args as appropriate
for i in 2:length(call.args)
call.args[i] = esc(call.args[i])
end
if expr.head == :do
expr.args[2] = esc(expr.args[2])
end

setup = quote
iterable = $(esc(obj))
$(esc(metersym)) = Progress(length(iterable), $(showprogress_process_args(progressargs)...))
end
# create appropriate Progress expression
lenex = :(ncalls($(esc(mapfun)), $(esc.(mapargs)...)))
progex = :(Progress($lenex, $(showprogress_process_args(progressargs)...)))

if expr.head === :for
return quote
$setup
$expr
end
else
# We're dealing with a comprehension
return quote
begin
$setup
rv = $orig
next!($(esc(metersym)))
rv
end
end
# insert progress and mapfun kwargs
push!(call.args, Expr(:kw, :progress, progex))
push!(call.args, Expr(:kw, :mapfun, esc(mapfun)))

return expr
end

function showprogress_loop(expr, progressargs)
metersym = gensym("meter")
orig = expr = copy(expr)

if expr.head == :for
outerassignidx = 1
loopbodyidx = lastindex(expr.args)
elseif expr.head == :comprehension
outerassignidx = lastindex(expr.args)
loopbodyidx = 1
elseif expr.head == :typed_comprehension
outerassignidx = lastindex(expr.args)
loopbodyidx = 2
end
# As of julia 0.5, a comprehension's "loop" is actually one level deeper in the syntax tree.
if expr.head !== :for
@assert length(expr.args) == loopbodyidx
expr = expr.args[outerassignidx] = copy(expr.args[outerassignidx])
if expr.head == :flatten
# e.g. [x for x in 1:10 for y in 1:x]
expr = expr.args[1] = copy(expr.args[1])
end
else # kind == :map
@assert expr.head === :generator
outerassignidx = lastindex(expr.args)
loopbodyidx = 1
end

# isolate call to map
if expr.head == :do
call = expr.args[1]
else
call = expr
# Transform the first loop assignment
loopassign = expr.args[outerassignidx] = copy(expr.args[outerassignidx])

if loopassign.head === :filter
# e.g. [x for x=1:10, y=1:10 if x>y]
# y will be wrapped in ProgressWrapper
for i in 1:length(loopassign.args)-1
loopassign.args[i] = esc(loopassign.args[i])
end
loopassign = loopassign.args[end] = copy(loopassign.args[end])
end

# get args to map to determine progress length
mapargs = collect(Any, filter(call.args[2:end]) do a
return isa(a, Symbol) || isa(a, Number) || !(a.head in (:kw, :parameters))
end)
if expr.head == :do
insert!(mapargs, 1, identity) # to make args for ncalls line up
if loopassign.head === :block
# e.g. for x=1:10, y=1:x end
# x will be wrapped in ProgressWrapper
for i in 2:length(loopassign.args)
loopassign.args[i] = esc(loopassign.args[i])
end
loopassign = loopassign.args[1] = copy(loopassign.args[1])
end

@assert loopassign.head === :(=)
@assert length(loopassign.args) == 2
obj = loopassign.args[2]
loopassign.args[1] = esc(loopassign.args[1])
loopassign.args[2] = :(ProgressWrapper(iterable, $(esc(metersym))))

# change call to progress_map
mapfun = call.args[1]
call.args[1] = :progress_map
# Transform the loop body break and return statements
if expr.head === :for
expr.args[loopbodyidx] = showprogress_process_expr(expr.args[loopbodyidx], metersym)
end

# escape args as appropriate
for i in 2:length(call.args)
call.args[i] = esc(call.args[i])
# Escape all args except the loop assignment, which was already appropriately escaped.
for i in 1:length(expr.args)
if i != outerassignidx
expr.args[i] = esc(expr.args[i])
end
if expr.head == :do
expr.args[2] = esc(expr.args[2])
end
if orig !== expr
# We have additional escaping to do; this will occur for comprehensions with julia 0.5 or later.
for i in 1:length(orig.args)-1
orig.args[i] = esc(orig.args[i])
end
end

# create appropriate Progress expression
lenex = :(ncalls($(esc(mapfun)), $(esc.(mapargs)...)))
progex = :(Progress($lenex, $(showprogress_process_args(progressargs)...)))

# insert progress and mapfun kwargs
push!(call.args, Expr(:kw, :progress, progex))
push!(call.args, Expr(:kw, :mapfun, esc(mapfun)))
setup = quote
iterable = $(esc(obj))
$(esc(metersym)) = Progress(length(iterable), $(showprogress_process_args(progressargs)...))
end

return expr
if expr.head === :for
return quote
$setup
$expr
end
else
# We're dealing with a comprehension
return quote
begin
$setup

Check warning on line 1050 in src/ProgressMeter.jl

View check run for this annotation

Codecov / codecov/patch

src/ProgressMeter.jl#L1050

Added line #L1050 was not covered by tests
rv = $orig
finish!($(esc(metersym)))
rv
end
end
end
end

Expand Down
4 changes: 2 additions & 2 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
@test ProgressMeter.durationstring(60*60*24*10 - 0.1) == "9 days, 23:59:59"
@test ProgressMeter.durationstring(60*60*24*10) == "10.00 days"

@test ProgressMeter.Progress(5, desc="Progress:", offset=Int16(5)).offset == 5
@test ProgressMeter.ProgressThresh(0.2, desc="Progress:", offset=Int16(5)).offset == 5
@test Progress(5, desc="Progress:", offset=Int16(5)).offset == 5
@test ProgressThresh(0.2, desc="Progress:", offset=Int16(5)).offset == 5

# test speed string formatting
for ns in [1, 9, 10, 99, 100, 999, 1_000, 9_999, 10_000, 99_000, 100_000, 999_999, 1_000_000, 9_000_000, 10_000_000, 99_999_000, 1_234_567_890, 1_234_567_890 * 10, 1_234_567_890 * 100, 1_234_567_890 * 1_000, 1_234_567_890 * 10_000, 1_234_567_890 * 100_000, 1_234_567_890 * 1_000_000, 1_234_567_890 * 10_000_000]
Expand Down
Loading
Loading