Mean overflows when using smaller types (e.g. Float16) #140
Yes, this is expected given the current implementation. At #22 we fixed overflow with integer types, but we haven't discussed the issue of precision of floating-point types. In the second example you give, the main issue is that @stevengj @StefanKarpinski Do you think we should do something about this?
Off the top of my head, I don't really see a good way to do an accurate Wouldn't it be reasonable to just have a specialized
Given that
One solution I thought of is breaking bigger arrays into smaller chunks. Something like this:
new_mean(A::AbstractArray{<:Union{Float16,ComplexF16}}; dims=:) = _mean_small(identity, A, dims)
function _mean_small(f, A::AbstractArray{<:Union{Float16,ComplexF16}}, dims::Dims=:) where Dims
    isempty(A) && return sum(f, A, dims=dims)/0
    if dims === (:)
        n = length(A)
    else
        n = mapreduce(i -> size(A, i), *, unique(dims); init=1)
    end
    x1 = f(first(A)) / 1
    if n > 10000
        chunks = [collect(1:10000:length(A)); length(A)]
        A_views = [view(A, chunks[i]:chunks[i+1]) for i in 1:length(chunks)-1]
        result = sum.(x -> _mean_promote(x1, f(x)), A_views, dims=dims) ./ (chunks[2:end] - chunks[1:end-1])
        return sum(result .* (chunks[2:end] - chunks[1:end-1])./10000) / (length(chunks)-1)
    else
        result = sum(x -> _mean_promote(x1, f(x)), A, dims=dims)
    end
    if dims === (:)
        return result / n
    else
        return result ./= n
    end
end
With this, the following results would be obtained:
julia> data=rand(Float16, 10^4)
10000-element Vector{Float16}:
julia> sum(data)
Float16(4.97e3)
julia> mean(data)
Float16(0.4968)
julia> new_mean(data)
Float16(0.4968)
julia> data=rand(Float16, 10^5)
100000-element Vector{Float16}:
julia> sum(data)
Float16(4.992e4)
julia> mean(data)
Float16(0.0)
julia> new_mean(data)
Float16(0.4993)
julia> data=rand(Float16, 10^6)
1000000-element Vector{Float16}:
julia> sum(data)
Inf16
julia> mean(data)
NaN16
julia> new_mean(data)
Float16(0.4998)
Doing some performance benchmarks:
julia> using BenchmarkTools
julia> @benchmark mean(data) setup=(data=rand(Float16, 10^5))
BenchmarkTools.Trial: 8926 samples with 1 evaluation.
Range (min … max): 392.300 μs … 989.200 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 397.000 μs ┊ GC (median): 0.00%
Time (mean ± σ): 400.816 μs ± 20.073 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
▄█ ▇ ▃
██▇█▆█▆▅▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
392 μs Histogram: frequency by time 475 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
julia> @benchmark new_mean(data) setup=(data=rand(Float16, 10^5))
BenchmarkTools.Trial: 9033 samples with 1 evaluation.
Range (min … max): 393.200 μs … 914.700 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 397.800 μs ┊ GC (median): 0.00%
Time (mean ± σ): 400.267 μs ± 16.346 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
▅ █▂ ▂
█▅▃████▆▅▆▄▃▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
393 μs Histogram: frequency by time 445 μs <
Memory estimate: 2.14 KiB, allocs estimate: 29.
julia> @benchmark mean(data) setup=(data=rand(Float16, 10^6))
BenchmarkTools.Trial: 905 samples with 1 evaluation.
Range (min … max): 3.942 ms … 7.279 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 4.096 ms ┊ GC (median): 0.00%
Time (mean ± σ): 4.139 ms ± 236.673 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
▄█▅▆▄▆▃▂▂▄▅▃▂▁
██████████████▇█▆▇▅▅▅▃▄▄▅▃▄▃▃▃▂▂▃▂▂▂▂▂▁▂▃▁▂▁▂▂▁▁▁▁▁▂▂▁▂▂▁▁▂ ▄
3.94 ms Histogram: frequency by time 4.95 ms <
Memory estimate: 0 bytes, allocs estimate: 0.
julia> @benchmark new_mean(data) setup=(data=rand(Float16, 10^6))
BenchmarkTools.Trial: 923 samples with 1 evaluation.
Range (min … max): 3.939 ms … 4.956 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 4.027 ms ┊ GC (median): 0.00%
Time (mean ± σ): 4.056 ms ± 110.197 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
▃▄▆▄▇█▆▆▇▄▂▂
▃▅████████████▇▇▅▆▄▄▃▄▃▃▃▄▄▄▃▂▃▃▃▂▃▃▂▁▃▃▂▂▂▂▁▂▂▃▂▂▂▂▁▂▂▃▁▂▃ ▄
3.94 ms Histogram: frequency by time 4.48 ms <
Memory estimate: 11.94 KiB, allocs estimate: 29.
This noticeably increases memory allocation, but I think that can be improved by replacing the chunk calculation with a loop.
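A minimal sketch of what such a loop could look like, assuming a plain vector and no `dims` support. The name `chunked_mean` and the `chunk` keyword are hypothetical, not part of Statistics.jl, and this is not the code that was benchmarked above. Each chunk's mean is weighted by its share of the total length, so no array of views or chunk boundaries is built.

```julia
# Hypothetical helper (not part of Statistics.jl): chunked mean for a vector.
# Each chunk is summed in the element type T, turned into a chunk mean, and
# added to the result with weight len / n.
function chunked_mean(xs::AbstractVector{T}; chunk::Int=10_000) where {T<:AbstractFloat}
    chunk >= 1 || throw(ArgumentError("chunk must be at least 1"))
    n = length(xs)
    n == 0 && return T(NaN)
    result = zero(T)
    lo = firstindex(xs)
    stop = lastindex(xs)
    while lo <= stop
        hi = min(lo + chunk - 1, stop)
        len = hi - lo + 1
        chunk_mean = sum(view(xs, lo:hi)) / T(len)
        # The weight is computed in Float64 and rounded to T once.
        result += T(len / n) * chunk_mean
        lo = hi + 1
    end
    return result
end

data = fill(Float16(64), 2048)          # true sum is 131072, above floatmax(Float16)
m = chunked_mean(data; chunk=256)       # each chunk sums to 16384, which fits
```

The per-chunk sum is still accumulated in `Float16`, so this only helps while every chunk's sum stays below `floatmax(Float16)`.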
Because the result of a I feel like we already resolved this argument with #25 and this long discourse thread? Of course, you could also have spurious overflow with
julia> mean(floatmax(Float32) * [1,1]) # should be 3.4028235f38
Inf32
julia> mean(floatmax(Float64) * [1,1]) # should be 1.7976931348623157e308
Inf
but this seems like a less pressing problem, as it's much less likely to happen unexpectedly than for
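For illustration only, one way to avoid this particular spurious overflow is to divide each element by the count before adding, so the accumulator tracks the mean rather than the sum. The helper name `scaled_mean` is hypothetical and this is not something proposed in the thread; it is a sketch assuming a plain vector of floats.

```julia
# Hypothetical helper: divide first, then accumulate.
# The accumulator never holds more than roughly the mean itself.
function scaled_mean(xs::AbstractVector{T}) where {T<:AbstractFloat}
    n = T(length(xs))
    acc = zero(T)
    for x in xs
        acc += x / n
    end
    return acc
end

big32 = floatmax(Float32) * [1, 1]   # Vector{Float32}
big64 = floatmax(Float64) * [1, 1]   # Vector{Float64}

m32 = scaled_mean(big32)
m64 = scaled_mean(big64)
```

The trade-offs are one division per element, possible underflow for very small values, and a return value of zero for an empty input. It also does not help `Float16` on long arrays, because the count itself converts to `Inf16` once it exceeds `floatmax(Float16)`.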
@stevengj Yeah, we've had a similar discussion before, but then the argument was that we should accumulate using the type that Anyway, you have a point that since the mean of @gabrielpreviato Your approach will help a bit, but it won't fix overflow if the sum of values in a block exceeds
I tried implementing accumulation using a wider type, but there's a problem: for heterogeneous collections, it's almost impossible to efficiently compute the type to which the result should be converted back. For example, with The only solution I can see is to do a separate pass just to compute the type. Maybe that's OK since that will only affect heterogeneous arrays (which are slow anyway).
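For the homogeneous case, where the element type is known up front, accumulation in a wider type can be sketched as follows. The name `wide_mean` is hypothetical and this is not the implementation that was tried. It relies on Base's `widen`, which maps `Float16` to `Float32` and `Float32` to `Float64`, and it sidesteps the heterogeneous-collection problem entirely by only accepting arrays with a concrete `Float16` or `Float32` element type.

```julia
# Hypothetical helper: accumulate in widen(T), convert the result back to T.
# Restricted to Float16 and Float32 because widen(Float64) is BigFloat.
function wide_mean(xs::AbstractArray{T}) where {T<:Union{Float16,Float32}}
    W = widen(T)
    acc = zero(W)
    for x in xs
        acc += W(x)
    end
    # For an empty input this is 0/0, i.e. NaN, converted back to T.
    return T(acc / length(xs))
end

half = fill(Float16(100), 1000)               # true sum is 100_000
single = floatmax(Float32) * Float32[1, 1]    # sum overflows Float32

m16 = wide_mean(half)
m32 = wide_mean(single)
```

No copy of the input is made; only the scalar accumulator is widened.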
Why not implement something that is numerically stable? https://discourse.julialang.org/t/why-are-missing-values-not-ignored-by-default/106756/286
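One common reading of "numerically stable" here is an incremental (running) mean, where the accumulator holds the mean of the elements seen so far instead of their sum. The sketch below is an assumption about what is meant, not necessarily what the linked post proposes, and the name `running_mean` is hypothetical.

```julia
# Hypothetical helper: incremental mean update m_k = m_{k-1} + (x_k - m_{k-1}) / k.
# The accumulator stays on the scale of the data rather than growing with n.
function running_mean(xs::AbstractVector{T}) where {T<:AbstractFloat}
    m = zero(T)
    for (k, x) in enumerate(xs)
        m += (x - m) / T(k)
    end
    return m
end

big32 = floatmax(Float32) * [1, 1]   # Vector{Float32}; the plain sum overflows
m = running_mean(big32)
```

This removes the overflow from summation, but the difference `x - m` can still overflow for values of opposite sign near `floatmax`, an empty input returns zero, and for `Float16` the counter `T(k)` stops being exactly representable above 2048, so accuracy on long `Float16` arrays is not guaranteed.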
See also JuliaLang/julia#52365. That would indeed fix the overflow with
Since the current mean implementation sums all elements and then divides by the total number of elements, smaller types such as Float16 can easily overflow on larger arrays, as you can see in the following example:
An easy solution when facing this is using Float32 instead, but I wanted to point out this issue when using Float16.
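A small deterministic reproduction of both points, written as a sketch rather than the original example from this report. The helper `loop_mean` is hypothetical and only mimics "sum everything, then divide" with an explicit loop in the element type; it is not the Statistics.jl implementation.

```julia
# Hypothetical helper: sum in the element type T, then divide by the count.
function loop_mean(xs::AbstractVector{T}) where {T<:AbstractFloat}
    acc = zero(T)
    for x in xs
        acc += x
    end
    return acc / T(length(xs))
end

data = fill(Float16(100), 1000)           # true mean is 100, true sum is 100_000

overflowed = loop_mean(data)              # accumulator passes floatmax(Float16) = 65504
workaround = loop_mean(Float32.(data))    # same loop after converting to Float32
```

Here `overflowed` is infinite while `workaround` recovers the expected mean, at the cost of allocating a `Float32` copy of the data.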