Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow for new map-like functions #296

Merged
merged 2 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ end
It possible to disable the progress meter when the use is optional.

```julia
x,n = 1,10
x, n = 1, 10
p = Progress(n; enabled = false)
for iter in 1:10
x *= 2
Expand All @@ -431,7 +431,25 @@ In cases where the output is text output such as CI or in an HPC scheduler, the
```julia
is_logging(io) = isa(io, Base.TTY) == false || (get(ENV, "CI", nothing) == "true")
p = Progress(n; output = stderr, enabled = !is_logging(stderr))
````
```

### Adding support for more map-like functions

To add support for other functions, `ProgressMeter.ncalls` must be defined,
where `ncalls_map`, `ncalls_broadcast`, `ncalls_broadcast!` or `ncalls_reduce` can help

For example, with `tmap` from [`ThreadTools.jl`](https://github.com/baggepinnen/ThreadTools.jl):

```julia
using ThreadTools, ProgressMeter

ProgressMeter.ncalls(::typeof(tmap), ::Function, args...) = ProgressMeter.ncalls_map(args...)
ProgressMeter.ncalls(::typeof(tmap), ::Function, ::Int, args...) = ProgressMeter.ncalls_map(args...)

@showprogress tmap(abs2, 1:10^5)
@showprogress tmap(abs2, 4, 1:10^5)
```


## Development/debugging tips

Expand Down
91 changes: 55 additions & 36 deletions src/ProgressMeter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@
"""
Equivalent of @showprogress for a distributed for loop.
```
result = @showprogress dt "Computing..." @distributed (+) for i = 1:50
result = @showprogress @distributed (+) for i = 1:50
sleep(0.1)
i^2
end
Expand Down Expand Up @@ -852,9 +852,14 @@
supply a custom message to be printed that specifies the computation
being performed or other options.

`@showprogress` works for loops, comprehensions, `asyncmap`,
`broadcast`, `broadcast!`, `foreach`, `map`, `mapfoldl`,
`mapfoldr`, `mapreduce`, `pmap` and `reduce`.
`@showprogress` works for loops, comprehensions, and `map`-like
functions. These `map`-like functions rely on `ncalls` being defined
and can be checked with `methods(ProgressMeter.ncalls)`. New ones can
be added by defining `ProgressMeter.ncalls(::typeof(mapfun), args...) = ...`.

`@showprogress` is thread-safe and will work with `@distributed` loops
as well as threaded or distributed functions like `pmap` and `asyncmap`.

Check warning on line 861 in src/ProgressMeter.jl

View check run for this annotation

Codecov / codecov/patch

src/ProgressMeter.jl#L861

Added line #L861 was not covered by tests

"""
macro showprogress(args...)
showprogress(args...)
Expand All @@ -875,8 +880,6 @@
return expr
end
metersym = gensym("meter")
mapfuns = (:asyncmap, :broadcast, :broadcast!, :foreach, :map,
:mapfoldl, :mapfoldr, :mapreduce, :pmap, :reduce)
kind = :invalid # :invalid, :loop, or :map

if isa(expr, Expr)
Expand All @@ -892,18 +895,18 @@
outerassignidx = lastindex(expr.args)
loopbodyidx = 2
kind = :loop
elseif expr.head == :call && expr.args[1] in mapfuns
elseif expr.head == :call
kind = :map
elseif expr.head == :do
call = expr.args[1]
if call.head == :call && call.args[1] in mapfuns
if call.head == :call
kind = :map
end
end
end

if kind == :invalid
throw(ArgumentError("Final argument to @showprogress must be a for loop, comprehension, map, reduce, or pmap; got $expr"))
throw(ArgumentError("Final argument to @showprogress must be a for loop, comprehension, or a map-like function; got $expr"))
elseif kind == :loop
# As of julia 0.5, a comprehension's "loop" is actually one level deeper in the syntax tree.
if expr.head !== :for
Expand Down Expand Up @@ -981,7 +984,7 @@
return isa(a, Symbol) || isa(a, Number) || !(a.head in (:kw, :parameters))
end)
if expr.head == :do
insert!(mapargs, 1, :nothing) # to make args for ncalls line up
insert!(mapargs, 1, identity) # to make args for ncalls line up
end

# change call to progress_map
Expand All @@ -997,7 +1000,7 @@
end

# create appropriate Progress expression
lenex = :(ncalls($(esc(mapfun)), ($([esc(a) for a in mapargs]...),)))
lenex = :(ncalls($(esc(mapfun)), $(esc.(mapargs)...)))
progex = :(Progress($lenex, $(showprogress_process_args(progressargs)...)))

# insert progress and mapfun kwargs
Expand All @@ -1014,10 +1017,12 @@
Run a `map`-like function while displaying progress.

`mapfun` can be any function, but it is only tested with `map`, `reduce` and `pmap`.
`ProgressMeter.ncalls(::typeof(mapfun), ::Function, args...)` must be defined to
specify the number of calls to `f`.
"""
function progress_map(args...; mapfun=map,
progress=Progress(ncalls(mapfun, args)),
channel_bufflen=min(1000, ncalls(mapfun, args)),
progress=Progress(ncalls(mapfun, args...)),
channel_bufflen=min(1000, ncalls(mapfun, args...)),
kwargs...)
isempty(args) && return mapfun(; kwargs...)
f = first(args)
Expand Down Expand Up @@ -1052,36 +1057,50 @@
progress_pmap(args...; kwargs...) = progress_map(args...; mapfun=pmap, kwargs...)

"""
Infer the number of calls to the mapped function (i.e. the length of the returned array) given the input arguments to map, reduce or pmap.
ProgressMeter.ncalls(::typeof(mapfun), ::Function, args...)

Infer the number of calls to the mapped function (often the length of the returned array)

Check warning on line 1062 in src/ProgressMeter.jl

View check run for this annotation

Codecov / codecov/patch

src/ProgressMeter.jl#L1062

Added line #L1062 was not covered by tests
to define the length of the `Progress` in `@showprogress` and `progress_map`.
Internally uses one of `ncalls_map`, `ncalls_broadcast(!)` or `ncalls_reduce` depending
on the type of `mapfun`.

Support for additional functions can be added by defining
`ProgressMeter.ncalls(::typeof(mapfun), ::Function, args...)`.
"""
function ncalls(::typeof(broadcast), map_args)
length(map_args) < 2 && return 1
return prod(length, Broadcast.combine_axes(map_args[2:end]...))
end
ncalls(::typeof(map), ::Function, args...) = ncalls_map(args...)
ncalls(::typeof(map!), ::Function, args...) = ncalls_map(args...)
ncalls(::typeof(foreach), ::Function, args...) = ncalls_map(args...)
ncalls(::typeof(asyncmap), ::Function, args...) = ncalls_map(args...)

function ncalls(::typeof(broadcast!), map_args)
length(map_args) < 2 && return 1
return length(map_args[2])
end
ncalls(::typeof(pmap), ::Function, args...) = ncalls_map(args...)
ncalls(::typeof(pmap), ::Function, ::AbstractWorkerPool, args...) = ncalls_map(args...)

function ncalls(::Union{typeof(mapreduce),typeof(mapfoldl),typeof(mapfoldr)}, map_args)
length(map_args) < 3 && return 1
return minimum(length, map_args[3:end])
ncalls(::typeof(mapfoldl), ::Function, ::Function, args...) = ncalls_map(args...)
ncalls(::typeof(mapfoldr), ::Function, ::Function, args...) = ncalls_map(args...)
ncalls(::typeof(mapreduce), ::Function, ::Function, args...) = ncalls_map(args...)

ncalls(::typeof(broadcast), ::Function, args...) = ncalls_broadcast(args...)
ncalls(::typeof(broadcast!), ::Function, args...) = ncalls_broadcast!(args...)

ncalls(::typeof(foldl), ::Function, arg) = ncalls_reduce(arg)
ncalls(::typeof(foldr), ::Function, arg) = ncalls_reduce(arg)
ncalls(::typeof(reduce), ::Function, arg) = ncalls_reduce(arg)

ncalls_reduce(arg) = length(arg) - 1

function ncalls_broadcast(args...)
length(args) < 1 && return 1
return prod(length, Broadcast.combine_axes(args...))
end

function ncalls(::typeof(pmap), map_args)
if length(map_args) ≥ 2 && map_args[2] isa AbstractWorkerPool
length(map_args) < 3 && return 1
return minimum(length, map_args[3:end])
else
length(map_args) < 2 && return 1
return minimum(length, map_args[2:end])
end
function ncalls_broadcast!(args...)
length(args) < 1 && return 1
return length(args[1])
end

function ncalls(mapfun::Function, map_args)
length(map_args) < 2 && return 1
return minimum(length, map_args[2:end])
function ncalls_map(args...)
length(args) < 1 && return 1
return minimum(length, args)
end

include("deprecated.jl")
Expand Down
90 changes: 52 additions & 38 deletions test/test_map.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,28 +54,52 @@ wp = WorkerPool(procs)
println()

# test ncalls
@test ncalls(map, (+, 1:10)) == 10
@test ncalls(pmap, (+, 1:10, 1:100)) == 10
@test ncalls(pmap, (+, wp, 1:10)) == 10
@test ncalls(reduce, (+, 1:10)) == 10
@test ncalls(mapreduce, (+, +, 1:10, (1:10)')) == 10
@test ncalls(mapfoldl, (+, +, 1:10, (1:10)')) == 10
@test ncalls(mapfoldr, (+, +, 1:10, (1:10)')) == 10
@test ncalls(foreach, (+, 1:10)) == 10
@test ncalls(broadcast, (+, 1:10, 1:10)) == 10
@test ncalls(broadcast, (+, 1:8, (1:7)', 1)) == 8*7
@test ncalls(broadcast, (+, 1:3, (1:5)', ones(1,1,2))) == 3*5*2
@test ncalls(broadcast!, (+, zeros(10,8))) == 80
@test ncalls(broadcast!, (+, zeros(10,8,7), 1:10)) == 10*8*7

@test ncalls(map, (time,)) == 1
@test ncalls(foreach, (time,)) == 1
@test ncalls(broadcast, (time,)) == 1
@test ncalls(broadcast!, (time, [1])) == 1
@test ncalls(mapreduce, (time, +)) == 1

@test_throws DimensionMismatch ncalls(broadcast, (+, 1:10, 1:100))
@test_throws DimensionMismatch ncalls(broadcast, (+, 1:100, 1:10))
@test ncalls(map, +, 1:10) == 10
@test ncalls(pmap, +, 1:10, 1:100) == 10
@test ncalls(pmap, +, wp, 1:10) == 10
@test ncalls(foldr, +, 1:10) == 9
@test ncalls(foldl, +, 1:10) == 9
@test ncalls(reduce, +, 1:10) == 9
@test ncalls(mapreduce, +, +, 1:10, (1:10)') == 10
@test ncalls(mapfoldl, +, +, 1:10, (1:10)') == 10
@test ncalls(mapfoldr, +, +, 1:10, (1:10)') == 10
@test ncalls(foreach, +, 1:10) == 10
@test ncalls(broadcast, +, 1:10, 1:10) == 10
@test ncalls(broadcast, +, 1:8, (1:7)', 1) == 8*7
@test ncalls(broadcast, +, 1:3, (1:5)', ones(1,1,2)) == 3*5*2
@test ncalls(broadcast!, +, zeros(10,8,7), 1:10) == 10*8*7

# functions with no args
# map(f) and foreach(f) were removed (#291)
@test ncalls(broadcast, time) == 1
@test ncalls(broadcast!, time, [1]) == 1
@test ncalls(broadcast!, time, zeros(10,8)) == 80
@test ncalls(mapreduce, time, +) == 1

@test_throws DimensionMismatch ncalls(broadcast, +, 1:10, 1:100)
@test_throws DimensionMismatch ncalls(broadcast, +, 1:100, 1:10)

@test_throws MethodError ncalls(map, 1:10, 1:10)
@test_throws MethodError @showprogress map(1:10, 1:10)

# test custom mapfun
mymap(f, x) = map(f, [x ; x])
@test_throws MethodError ncalls(mymap, +, 1:10)
@test_throws MethodError @showprogress mymap(+, 1:10)

ProgressMeter.ncalls(::typeof(mymap), ::Function, args...) = 2*ProgressMeter.ncalls_map(args...)
@test ncalls(mymap, +, 1:10) == 20

println("Testing custom map")
vals = @showprogress mymap(1:10) do x
sleep(0.1)
return x^2
end
@test vals == map(x->x^2, [1:10; 1:10])

println("Testing custom map with kwarg (color red)")
vals = @showprogress color=:red mymap(x->(sleep(0.1); x^2), 1:10)
@test vals == map(x->x^2, [1:10; 1:10])

# @showprogress
vals = @showprogress map(1:10) do x
Expand Down Expand Up @@ -137,9 +161,7 @@ wp = WorkerPool(procs)
return x
end
@test A == repeat(1:10, 1, 8)




# function passed by name
function testfun(x)
return x^2
Expand Down Expand Up @@ -172,7 +194,6 @@ wp = WorkerPool(procs)
@test broadcast(constfun) == @showprogress broadcast(constfun)
#@test mapreduce(constfun, error) == @showprogress mapreduce(constfun, error) # julia 1.2+


# #136: make sure mid progress shows up even without sleep
println("Verify that intermediate progress is displayed:")
@showprogress map(1:100) do i
Expand All @@ -184,41 +205,34 @@ wp = WorkerPool(procs)
vals = @showprogress pmap((x,y)->x*y, 1:10, 2:11)
@test vals == map((x,y)->x*y, 1:10, 2:11)







# Progress args
vals = @showprogress dt=0.1 desc="Computing" pmap(testfun, 1:10)
@test vals == map(testfun, 1:10)



# named vector arg
a = collect(1:10)
vals = @showprogress pmap(x->x^2, a)
@test vals == map(x->x^2, a)



# global variable in do
C = 10
vals = @showprogress pmap(1:10) do x
return C*x
end
@test vals == map(x->C*x, 1:10)



# keyword arguments
vals = @showprogress pmap(x->x^2, 1:100, batch_size=10)
@test vals == map(x->x^2, 1:100)
# with semicolon
vals = @showprogress pmap(x->x^2, 1:100; batch_size=10)
@test vals == map(x->x^2, 1:100)

A = rand(0:999, 7, 11, 13)
vals = @showprogress mapreduce(abs2, +, A; dims=1, init=0)
@test vals == mapreduce(abs2, +, A; dims=1, init=0)
vals = @showprogress mapfoldl(abs2, -, A; init=1)
@test vals == mapfoldl(abs2, -, A; init=1)

# pipes after map
@showprogress map(testfun, 1:10) |> sum |> length
Expand Down
Loading