diff --git a/README.md b/README.md index 4fa40dd..850cf09 100644 --- a/README.md +++ b/README.md @@ -416,7 +416,7 @@ end It possible to disable the progress meter when the use is optional. ```julia -x,n = 1,10 +x, n = 1, 10 p = Progress(n; enabled = false) for iter in 1:10 x *= 2 @@ -431,7 +431,25 @@ In cases where the output is text output such as CI or in an HPC scheduler, the ```julia is_logging(io) = isa(io, Base.TTY) == false || (get(ENV, "CI", nothing) == "true") p = Progress(n; output = stderr, enabled = !is_logging(stderr)) -```` +``` + +### Adding support for more map-like functions + +To add support for other functions, `ProgressMeter.ncalls` must be defined, +where `ncalls_map`, `ncalls_broadcast`, `ncalls_broadcast!` or `ncalls_reduce` can help + +For example, with `tmap` from [`ThreadTools.jl`](https://github.com/baggepinnen/ThreadTools.jl): + +```julia +using ThreadTools, ProgressMeter + +ProgressMeter.ncalls(::typeof(tmap), ::Function, args...) = ProgressMeter.ncalls_map(args...) +ProgressMeter.ncalls(::typeof(tmap), ::Function, ::Int, args...) = ProgressMeter.ncalls_map(args...) + +@showprogress tmap(abs2, 1:10^5) +@showprogress tmap(abs2, 4, 1:10^5) +``` + ## Development/debugging tips diff --git a/src/ProgressMeter.jl b/src/ProgressMeter.jl index 4b03c62..230970b 100644 --- a/src/ProgressMeter.jl +++ b/src/ProgressMeter.jl @@ -761,7 +761,7 @@ end """ Equivalent of @showprogress for a distributed for loop. ``` -result = @showprogress dt "Computing..." @distributed (+) for i = 1:50 +result = @showprogress @distributed (+) for i = 1:50 sleep(0.1) i^2 end @@ -852,9 +852,14 @@ displays progress in performing a computation. You may optionally supply a custom message to be printed that specifies the computation being performed or other options. -`@showprogress` works for loops, comprehensions, `asyncmap`, -`broadcast`, `broadcast!`, `foreach`, `map`, `mapfoldl`, -`mapfoldr`, `mapreduce`, `pmap` and `reduce`. +`@showprogress` works for loops, comprehensions, and `map`-like +functions. These `map`-like functions rely on `ncalls` being defined +and can be checked with `methods(ProgressMeter.ncalls)`. New ones can +be added by defining `ProgressMeter.ncalls(::typeof(mapfun), args...) = ...`. + +`@showprogress` is thread-safe and will work with `@distributed` loops +as well as threaded or distributed functions like `pmap` and `asyncmap`. + """ macro showprogress(args...) showprogress(args...) @@ -875,8 +880,6 @@ function showprogress(args...) return expr end metersym = gensym("meter") - mapfuns = (:asyncmap, :broadcast, :broadcast!, :foreach, :map, - :mapfoldl, :mapfoldr, :mapreduce, :pmap, :reduce) kind = :invalid # :invalid, :loop, or :map if isa(expr, Expr) @@ -892,18 +895,18 @@ function showprogress(args...) outerassignidx = lastindex(expr.args) loopbodyidx = 2 kind = :loop - elseif expr.head == :call && expr.args[1] in mapfuns + elseif expr.head == :call kind = :map elseif expr.head == :do call = expr.args[1] - if call.head == :call && call.args[1] in mapfuns + if call.head == :call kind = :map end end end if kind == :invalid - throw(ArgumentError("Final argument to @showprogress must be a for loop, comprehension, map, reduce, or pmap; got $expr")) + throw(ArgumentError("Final argument to @showprogress must be a for loop, comprehension, or a map-like function; got $expr")) elseif kind == :loop # As of julia 0.5, a comprehension's "loop" is actually one level deeper in the syntax tree. if expr.head !== :for @@ -981,7 +984,7 @@ function showprogress(args...) return isa(a, Symbol) || isa(a, Number) || !(a.head in (:kw, :parameters)) end) if expr.head == :do - insert!(mapargs, 1, :nothing) # to make args for ncalls line up + insert!(mapargs, 1, identity) # to make args for ncalls line up end # change call to progress_map @@ -997,7 +1000,7 @@ function showprogress(args...) end # create appropriate Progress expression - lenex = :(ncalls($(esc(mapfun)), ($([esc(a) for a in mapargs]...),))) + lenex = :(ncalls($(esc(mapfun)), $(esc.(mapargs)...))) progex = :(Progress($lenex, $(showprogress_process_args(progressargs)...))) # insert progress and mapfun kwargs @@ -1014,10 +1017,12 @@ end Run a `map`-like function while displaying progress. `mapfun` can be any function, but it is only tested with `map`, `reduce` and `pmap`. +`ProgressMeter.ncalls(::typeof(mapfun), ::Function, args...)` must be defined to +specify the number of calls to `f`. """ function progress_map(args...; mapfun=map, - progress=Progress(ncalls(mapfun, args)), - channel_bufflen=min(1000, ncalls(mapfun, args)), + progress=Progress(ncalls(mapfun, args...)), + channel_bufflen=min(1000, ncalls(mapfun, args...)), kwargs...) isempty(args) && return mapfun(; kwargs...) f = first(args) @@ -1052,36 +1057,50 @@ Run `pmap` while displaying progress. progress_pmap(args...; kwargs...) = progress_map(args...; mapfun=pmap, kwargs...) """ -Infer the number of calls to the mapped function (i.e. the length of the returned array) given the input arguments to map, reduce or pmap. + ProgressMeter.ncalls(::typeof(mapfun), ::Function, args...) + +Infer the number of calls to the mapped function (often the length of the returned array) +to define the length of the `Progress` in `@showprogress` and `progress_map`. +Internally uses one of `ncalls_map`, `ncalls_broadcast(!)` or `ncalls_reduce` depending +on the type of `mapfun`. + +Support for additional functions can be added by defining +`ProgressMeter.ncalls(::typeof(mapfun), ::Function, args...)`. """ -function ncalls(::typeof(broadcast), map_args) - length(map_args) < 2 && return 1 - return prod(length, Broadcast.combine_axes(map_args[2:end]...)) -end +ncalls(::typeof(map), ::Function, args...) = ncalls_map(args...) +ncalls(::typeof(map!), ::Function, args...) = ncalls_map(args...) +ncalls(::typeof(foreach), ::Function, args...) = ncalls_map(args...) +ncalls(::typeof(asyncmap), ::Function, args...) = ncalls_map(args...) -function ncalls(::typeof(broadcast!), map_args) - length(map_args) < 2 && return 1 - return length(map_args[2]) -end +ncalls(::typeof(pmap), ::Function, args...) = ncalls_map(args...) +ncalls(::typeof(pmap), ::Function, ::AbstractWorkerPool, args...) = ncalls_map(args...) -function ncalls(::Union{typeof(mapreduce),typeof(mapfoldl),typeof(mapfoldr)}, map_args) - length(map_args) < 3 && return 1 - return minimum(length, map_args[3:end]) +ncalls(::typeof(mapfoldl), ::Function, ::Function, args...) = ncalls_map(args...) +ncalls(::typeof(mapfoldr), ::Function, ::Function, args...) = ncalls_map(args...) +ncalls(::typeof(mapreduce), ::Function, ::Function, args...) = ncalls_map(args...) + +ncalls(::typeof(broadcast), ::Function, args...) = ncalls_broadcast(args...) +ncalls(::typeof(broadcast!), ::Function, args...) = ncalls_broadcast!(args...) + +ncalls(::typeof(foldl), ::Function, arg) = ncalls_reduce(arg) +ncalls(::typeof(foldr), ::Function, arg) = ncalls_reduce(arg) +ncalls(::typeof(reduce), ::Function, arg) = ncalls_reduce(arg) + +ncalls_reduce(arg) = length(arg) - 1 + +function ncalls_broadcast(args...) + length(args) < 1 && return 1 + return prod(length, Broadcast.combine_axes(args...)) end -function ncalls(::typeof(pmap), map_args) - if length(map_args) ≥ 2 && map_args[2] isa AbstractWorkerPool - length(map_args) < 3 && return 1 - return minimum(length, map_args[3:end]) - else - length(map_args) < 2 && return 1 - return minimum(length, map_args[2:end]) - end +function ncalls_broadcast!(args...) + length(args) < 1 && return 1 + return length(args[1]) end -function ncalls(mapfun::Function, map_args) - length(map_args) < 2 && return 1 - return minimum(length, map_args[2:end]) +function ncalls_map(args...) + length(args) < 1 && return 1 + return minimum(length, args) end include("deprecated.jl") diff --git a/test/test_map.jl b/test/test_map.jl index 24fb40b..a57ebbf 100644 --- a/test/test_map.jl +++ b/test/test_map.jl @@ -54,28 +54,52 @@ wp = WorkerPool(procs) println() # test ncalls - @test ncalls(map, (+, 1:10)) == 10 - @test ncalls(pmap, (+, 1:10, 1:100)) == 10 - @test ncalls(pmap, (+, wp, 1:10)) == 10 - @test ncalls(reduce, (+, 1:10)) == 10 - @test ncalls(mapreduce, (+, +, 1:10, (1:10)')) == 10 - @test ncalls(mapfoldl, (+, +, 1:10, (1:10)')) == 10 - @test ncalls(mapfoldr, (+, +, 1:10, (1:10)')) == 10 - @test ncalls(foreach, (+, 1:10)) == 10 - @test ncalls(broadcast, (+, 1:10, 1:10)) == 10 - @test ncalls(broadcast, (+, 1:8, (1:7)', 1)) == 8*7 - @test ncalls(broadcast, (+, 1:3, (1:5)', ones(1,1,2))) == 3*5*2 - @test ncalls(broadcast!, (+, zeros(10,8))) == 80 - @test ncalls(broadcast!, (+, zeros(10,8,7), 1:10)) == 10*8*7 - - @test ncalls(map, (time,)) == 1 - @test ncalls(foreach, (time,)) == 1 - @test ncalls(broadcast, (time,)) == 1 - @test ncalls(broadcast!, (time, [1])) == 1 - @test ncalls(mapreduce, (time, +)) == 1 - - @test_throws DimensionMismatch ncalls(broadcast, (+, 1:10, 1:100)) - @test_throws DimensionMismatch ncalls(broadcast, (+, 1:100, 1:10)) + @test ncalls(map, +, 1:10) == 10 + @test ncalls(pmap, +, 1:10, 1:100) == 10 + @test ncalls(pmap, +, wp, 1:10) == 10 + @test ncalls(foldr, +, 1:10) == 9 + @test ncalls(foldl, +, 1:10) == 9 + @test ncalls(reduce, +, 1:10) == 9 + @test ncalls(mapreduce, +, +, 1:10, (1:10)') == 10 + @test ncalls(mapfoldl, +, +, 1:10, (1:10)') == 10 + @test ncalls(mapfoldr, +, +, 1:10, (1:10)') == 10 + @test ncalls(foreach, +, 1:10) == 10 + @test ncalls(broadcast, +, 1:10, 1:10) == 10 + @test ncalls(broadcast, +, 1:8, (1:7)', 1) == 8*7 + @test ncalls(broadcast, +, 1:3, (1:5)', ones(1,1,2)) == 3*5*2 + @test ncalls(broadcast!, +, zeros(10,8,7), 1:10) == 10*8*7 + + # functions with no args + # map(f) and foreach(f) were removed (#291) + @test ncalls(broadcast, time) == 1 + @test ncalls(broadcast!, time, [1]) == 1 + @test ncalls(broadcast!, time, zeros(10,8)) == 80 + @test ncalls(mapreduce, time, +) == 1 + + @test_throws DimensionMismatch ncalls(broadcast, +, 1:10, 1:100) + @test_throws DimensionMismatch ncalls(broadcast, +, 1:100, 1:10) + + @test_throws MethodError ncalls(map, 1:10, 1:10) + @test_throws MethodError @showprogress map(1:10, 1:10) + + # test custom mapfun + mymap(f, x) = map(f, [x ; x]) + @test_throws MethodError ncalls(mymap, +, 1:10) + @test_throws MethodError @showprogress mymap(+, 1:10) + + ProgressMeter.ncalls(::typeof(mymap), ::Function, args...) = 2*ProgressMeter.ncalls_map(args...) + @test ncalls(mymap, +, 1:10) == 20 + + println("Testing custom map") + vals = @showprogress mymap(1:10) do x + sleep(0.1) + return x^2 + end + @test vals == map(x->x^2, [1:10; 1:10]) + + println("Testing custom map with kwarg (color red)") + vals = @showprogress color=:red mymap(x->(sleep(0.1); x^2), 1:10) + @test vals == map(x->x^2, [1:10; 1:10]) # @showprogress vals = @showprogress map(1:10) do x @@ -137,9 +161,7 @@ wp = WorkerPool(procs) return x end @test A == repeat(1:10, 1, 8) - - - + # function passed by name function testfun(x) return x^2 @@ -172,7 +194,6 @@ wp = WorkerPool(procs) @test broadcast(constfun) == @showprogress broadcast(constfun) #@test mapreduce(constfun, error) == @showprogress mapreduce(constfun, error) # julia 1.2+ - # #136: make sure mid progress shows up even without sleep println("Verify that intermediate progress is displayed:") @showprogress map(1:100) do i @@ -184,25 +205,15 @@ wp = WorkerPool(procs) vals = @showprogress pmap((x,y)->x*y, 1:10, 2:11) @test vals == map((x,y)->x*y, 1:10, 2:11) - - - - - - # Progress args vals = @showprogress dt=0.1 desc="Computing" pmap(testfun, 1:10) @test vals == map(testfun, 1:10) - - # named vector arg a = collect(1:10) vals = @showprogress pmap(x->x^2, a) @test vals == map(x->x^2, a) - - # global variable in do C = 10 vals = @showprogress pmap(1:10) do x @@ -210,8 +221,6 @@ wp = WorkerPool(procs) end @test vals == map(x->C*x, 1:10) - - # keyword arguments vals = @showprogress pmap(x->x^2, 1:100, batch_size=10) @test vals == map(x->x^2, 1:100) @@ -219,6 +228,11 @@ wp = WorkerPool(procs) vals = @showprogress pmap(x->x^2, 1:100; batch_size=10) @test vals == map(x->x^2, 1:100) + A = rand(0:999, 7, 11, 13) + vals = @showprogress mapreduce(abs2, +, A; dims=1, init=0) + @test vals == mapreduce(abs2, +, A; dims=1, init=0) + vals = @showprogress mapfoldl(abs2, -, A; init=1) + @test vals == mapfoldl(abs2, -, A; init=1) # pipes after map @showprogress map(testfun, 1:10) |> sum |> length