# More accurate streaming `mean` method for arbitrary iterators (#52365)
This problem is not unique to `mean`; the same loss of precision shows up in `sum`:

```julia
julia> sum(1f0 for _ in 1:1e8)
1.6777216f7
```

That said, there are pathological inputs to most (all?) uncompensated sequential accumulations. It's worth ensuring an algorithm in this flavor doesn't simply introduce different pathologies of equal concern. One unfortunate, although (arguably) not devastating, consequence of an algorithm like this one is the propensity for results like this:

```julia
julia> streaming_mean([-1,-1,-1,1,1,1])
-2.7755575615628914e-17
```
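Where that tiny nonzero result comes from (a sketch; this assumes `streaming_mean` uses the running update `a += (x - a) / (i + 1)` quoted later in the thread): the intermediate running means pass through values like `-1/5` that are not exactly representable in binary, so the final update cannot cancel exactly:

```julia
julia> a = -1.0;  # running mean after the first element

julia> for (i, x) in enumerate([-1, -1, 1, 1, 1])  # fold in the remaining elements
           a += (x - a) / (i + 1)
       end

julia> a  # exactly -2^-55: a rounding residue from passing through -1/5, not zero
-2.7755575615628914e-17
```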
---

A. It's intriguing, and confusing, why your version is not deterministic.

The third version, which is much faster (but is it OK?), uses `a += (x - a) * inv(Float32(i + 1))`.

B. FYI: I'm not sure (do you know exactly?) why your example gives `1.6777216f7`, except likely because it uses an iterator, which was your point, that they're broken for this. At first I thought you meant to write what I wrote, and since that's legal for me in that context, likely so. That's non-ideal (on 1.10rc1), though on 1.9.4:

[if this should actually ERROR, then why not when wrapped in a `sum`?]
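For concreteness, a minimal sketch of that `inv`-based variant (my reconstruction from the one-line description above; the function name and surrounding structure are assumptions, not code from the thread):

```julia
# Hypothetical sketch: replace the per-element division by a reciprocal multiply.
function streaming_mean_inv(itr)
    a, rest = Iterators.peel(itr)  # first element seeds the running mean
    for (i, x) in enumerate(rest)
        a += (x - a) * inv(Float32(i + 1))  # one inv + one mul per element
    end
    return a
end
```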
---

Parentheses are necessary to construct generators in some contexts:

```julia
julia> 1f0 for _ in 1:100 # invalid syntax
ERROR: syntax: extra token "for" after end of expression
Stacktrace:
 [1] top-level scope
   @ none:1

julia> (1f0 for _ in 1:100) # generator
Base.Generator{UnitRange{Int64}, var"#104#105"}(var"#104#105"(), 1:100)

julia> [1f0 for _ in 1:100] |> typeof # comprehension
Vector{Float32} (alias for Array{Float32, 1})
```

The original example is non-deterministic because it is a generator that calls `rand`, so it produces new values every time it is iterated. The discrepancy between the two `sum` results below is that summing a generator accumulates sequentially, while summing an array uses pairwise summation:

```julia
julia> maxintfloat(Float32) # adding 1 to this value has no effect
1.6777216f7

julia> sum(1f0 for _ in 1:1e8) # sequential sum
1.6777216f7

julia> sum([1f0 for _ in 1:1e8]) # pairwise sum
1.0f8
```

There isn't anything intrinsically precluding a pairwise summation algorithm over general iterators, although it's a bit more complicated to implement. That could be used to improve the accuracy here. For what it's worth, the proposed `streaming_mean` has pathological inputs of its own:

```julia
julia> streaming_mean(Iterators.flatten((Iterators.repeated(0f0, 10^8), Iterators.repeated(1f0, 10^8))))
0.13370937f0 # not 0.5f0
```
---

Interesting. I don't know much about floating-point arithmetic. I would just hope that there would be some way that we could improve the numerical accuracy situation for `mean` over arbitrary iterators.
---

I know the difference, sort of, i.e. the downsides of the latter. Is the former always better, and can't it also just be used for generators? The (absolute and) relative error seems large, i.e. 100000000/16777216 = 5.96x, and I know you chose the example to highlight that, but it seems a consequence of floating point, and both results can be defended as "right". It might not be that common to sum up that many numbers (ones that diverge up or down, rather than converging to a smaller value). At least the other approach isn't immune either:

People should be in the know and use `Float64` if worried (or should a `Float32` sum accumulate in `Float64` and then convert to `Float32` at the end? That's not bad, except on GPUs, where `Float64` may be unavailable or slower). You end up with the same problem with `Float64`, even Float128 and `BigFloat`, at least if you sum sufficiently large and/or many values.
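A quick illustration of that widen-then-convert idea (the outputs follow from `maxintfloat(Float32)`, discussed above; `Float64` counts to 10^8 exactly):

```julia
julia> sum(1f0 for _ in 1:10^8)  # the Float32 accumulator stalls at maxintfloat(Float32)
1.6777216f7

julia> Float32(sum(Float64(1f0) for _ in 1:10^8))  # accumulate in Float64, convert at the end
1.0f8
```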
That's good to know, and plausible, though I do not see it documented at `help?> Base.Generator`, nor do I see a need or good reason for `rand`; but if I'm looking at the right place, shouldn't it be documented (if/when it has user-visible consequences)? Also, since we're just adding up the same values (1s), it seems the order shouldn't matter. Why is `streaming_mean`, or what it uses, nondeterministic? Because of `enumerate` presumably, though its docs don't mention this (nor that it uses a generator; I suppose it's not because of `Iterators.peel`, whose docs do not indicate it):

That seems to imply a for loop and no `rand` needed...
---

It's nothing in particular about generators; it's just that I put `rand` inside the generator, so it draws fresh random numbers each time it is iterated.
---

I've changed the title to "more accurate" rather than "numerically stable". According to the mathematical definition, the current algorithm (naive in-order summation) *is* numerically stable; naive summation is backwards stable. (The Wikipedia article on this subject is truly terrible, by the way; it never gives a precise definition.) The fact that the error grows with `n` does not make it unstable.
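For reference (an addition, not from the thread), the standard worst-case bound for naive recursive summation (e.g. Higham, *Accuracy and Stability of Numerical Algorithms*) is

$$\left|\widehat{s}_n - \sum_{i=1}^n x_i\right| \le (n-1)\,u \sum_{i=1}^n |x_i| + O(u^2),$$

where $\widehat{s}_n$ is the computed sum and $u$ is the unit roundoff ($2^{-24}$ for `Float32`, $2^{-53}$ for `Float64`). The error *bound* grows linearly in $n$, and the typical (rms) error for random inputs grows like $O(\sqrt{n})$, but neither contradicts backward stability.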
---

It may be true that the current algorithm in
---

I'm not sure of the relative advantages of each approach in terms of accuracy, but at least in terms of performance the advantage of the current one is clear. Taking the example from the OP:

```julia
julia> using Statistics, BenchmarkTools

julia> A = rand(Float32, 100_000_000);

julia> itr = (x for x in A);

julia> @btime mean(A);
  23.134 ms (1 allocation: 16 bytes)

julia> @btime streaming_mean(A);
  463.581 ms (1 allocation: 16 bytes)

julia> @btime mean(itr);
  106.404 ms (1 allocation: 16 bytes)

julia> @btime streaming_mean(itr);
  474.276 ms (1 allocation: 16 bytes)
```

Of course, returning a wrong solution quickly isn't useful, but we have to make sure the proposed algorithm is strictly and significantly better than the existing one before switching.

Also, as noted at JuliaStats/Statistics.jl#140 (comment), for small float types, a good solution would be to accumulate in `Float64`:

```julia
julia> itr2 = (Float64(x) for x in A);

julia> @btime mean(itr2);
  106.298 ms (1 allocation: 16 bytes)
```

(The difficulty is how to handle heterogeneous arrays, as discussed in that issue.)
---

I'm no expert on floating-point arithmetic, but based on the tradeoffs described by @mikmoore above, it seems like "strictly better" (in terms of numerical accuracy) is too high a bar. Perhaps "more accurate in the large majority of likely use cases" would be a better decision rule. Of course, I imagine it's harder to assess whether that rule has been satisfied.
---

The proposed algorithm (discussed in this stackoverflow post) does one division per iteration. If you're willing to pay that price, you might as well do Kahan summation, which does extra additions but no division. For example:

```julia
function streaming_mean_kahan(itr)
    s, rest = Iterators.peel(itr)
    c = zero(s)
    count = 1
    for x in rest
        y = x - c
        t = s + y
        c = (t - s) - y
        s = t
        count += 1
    end
    return s / count
end
```

(which doesn't handle the empty-iterator case, like the original post). This gives

```julia
julia> streaming_mean_kahan(itr)
0.4999869f0
```

on the test above. Even without optimizing it, it's 20% faster than the division-based algorithm on my machine:

```julia
julia> @btime mean($A); @btime mean($itr); @btime streaming_mean($itr); @btime streaming_mean_kahan($itr);
  20.456 ms (0 allocations: 0 bytes)
  150.730 ms (0 allocations: 0 bytes)
  479.155 ms (0 allocations: 0 bytes)
  402.276 ms (0 allocations: 0 bytes)
```

(Both algorithms can easily be further optimized by summing multiple sub-streams at once with SIMD.)
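As a concrete sketch of the sub-stream idea (my own illustration, not code from the thread; `sum4` is a hypothetical name): naive summation with four independent accumulators removes the loop-carried dependence, which is what lets the compiler use SIMD lanes as independent accumulators.

```julia
# Hypothetical sketch: sum a vector with 4 independent accumulators (sub-streams).
function sum4(A::Vector{T}) where {T<:AbstractFloat}
    s1 = s2 = s3 = s4 = zero(T)
    n = length(A) - length(A) % 4
    @inbounds for i in 1:4:n
        s1 += A[i]
        s2 += A[i+1]
        s3 += A[i+2]
        s4 += A[i+3]
    end
    s = (s1 + s2) + (s3 + s4)
    @inbounds for i in (n+1):length(A)  # remaining 0-3 elements
        s += A[i]
    end
    return s
end
```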
---

A simple thing that can increase performance without losing much accuracy is to do Kahan summation "chunked" to every 4 numbers at a time. That is, we sum 4 consecutive numbers naively, then do a Kahan step, then another 4, and so on:

```julia
function streaming_mean_kahan4(itr)
    it = iterate(itr)
    it === nothing && return zero(eltype(itr)) / 1
    s_, state = it
    T = typeof(s_ / 1) # accumulate in the precision of the result
    s::T = s_
    c = zero(s)
    count = 1
    while true
        it = iterate(itr, state)
        it === nothing && return s / count
        x1::T, state = it
        it = iterate(itr, state)
        it === nothing && return (s+x1) / (count+1)
        x2::T, state = it
        it = iterate(itr, state)
        it === nothing && return (s+x1+x2) / (count+2)
        x3::T, state = it
        it = iterate(itr, state)
        it === nothing && return (s+x1+x2+x3) / (count+3)
        x4::T, state = it
        # apply Kahan summation chunked 4 elements at a time
        @fastmath x = x1+x2+x3+x4 # (re-associating this line is allowed)
        y = x - c
        t = s + y
        c = (t - s) - y
        s = t
        count += 4
    end
end
```

On my machine, this is actually slightly faster than our current `mean(itr)`:

```julia
julia> @btime streaming_mean_kahan4($itr);
  126.799 ms (0 allocations: 0 bytes)
```

(vs. 150.730 ms for `mean($itr)` above), and it remains accurate:

```julia
julia> streaming_mean_kahan4(itr)
0.50002027f0
```
---

There's no reason we can't use pairwise reduction on an iterator. As a proof-of-concept (with bad performance and a hideous implementation), see below. Someone better-versed in iterators should be able to make something more performant and legible. They should probably also use a larger base case.

```julia
function pairwisesum(itr)
    valstate = iterate(itr)
    isnothing(valstate) && error("empty iterator but we need to furnish the correct error message")
    val, state = valstate
    s = val
    lognumtosum = 0
    while true
        s1state = _pairwisesum(itr, state, lognumtosum) # sum of the next 2^lognumtosum elements, or fewer if we run out early
        isnothing(s1state) && return s # iterator was already exhausted
        s1, state = s1state::Tuple
        s += s1 # sum of the first 2 * 2^lognumtosum elements, or fewer if we run out early
        isnothing(state) && return s # iterator is exhausted
        lognumtosum += 1
    end
end

function _pairwisesum(itr, state, lognumtosum)
    if lognumtosum < 1
        valstate = iterate(itr, state)
        isnothing(valstate) && return nothing
        val, state = valstate::Tuple
        return val, state
    else
        s1state = _pairwisesum(itr, state, lognumtosum-true)
        isnothing(s1state) && return nothing # failed to get any values
        s1, state = s1state::Tuple
        isnothing(state) && return s1, nothing # got values but iterator is now exhausted
        s2state = _pairwisesum(itr, state, lognumtosum-true)
        isnothing(s2state) && return s1, nothing # got no further values from second call
        s2, state = s2state::Tuple
        s12 = s1 + s2
        # NOTE: we split the following return statements to improve type stability
        # Without the explicit isnothing check, return Tuple{...,Union{...,Nothing}}
        # With the explicit isnothing check, return Union{Tuple{...,...},Tuple{...,Nothing}}
        isnothing(state) && return s12, nothing
        return s12, state # combine values and return
    end
end
```
---

Correction: in order to assess the accuracy, I have to change the iterator to

```julia
itr = Iterators.take(A, length(A))
```

so that it is using the same random numbers every time. Then

```julia
exact = sum(Float64, A) / length(A)
@show mean(A) - exact
@show streaming_mean(itr) - exact
@show streaming_mean_kahan(itr) - exact
@show streaming_mean_kahan4(itr) - exact
@show pairwisesum(itr) / length(A) - exact
```

gives:

```julia
mean(A) - exact = -1.7564311005635602e-8
streaming_mean(itr) - exact = -1.0656993403412862e-5
streaming_mean_kahan(itr) - exact = -1.7564311005635602e-8
streaming_mean_kahan4(itr) - exact = -1.7564311005635602e-8
pairwisesum(itr) / length(A) - exact = -1.7564311005635602e-8
```

i.e. the pairwise and Kahan variants all match the accuracy of the built-in (pairwise) `mean(A)`, while the division-based `streaming_mean` is nearly three orders of magnitude less accurate here.
---

I tried to implement a single-pass pairwise streaming mean that has an enlarged/coarsened base case (switching to iterative mean at 128 elements):

```julia
const PAIRWISE_BASE = 128

function streaming_mean_pairwise(itr)
    it = iterate(itr)
    it === nothing && return zero(eltype(itr)) / 1
    x, state = it
    n = PAIRWISE_BASE
    sum, count, state = _sum_count_n(x / 1, 1, itr, state, n - 1)
    while state !== nothing
        sum, count, state = _streaming_mean_pairwise(sum, count, itr, state, n)
        n *= 2
    end
    return sum / count
end

# base case: accumulate sum and count for at most n elements of itr,
# starting from given iteration state, returning accumulated sum, count, state
# where state is final iteration state, or nothing if end of iterator encountered
function _sum_count_n(sum, count, itr, state, n)
    for i = 1:n
        it = iterate(itr, state)
        it === nothing && return sum, count+i-1, nothing
        x, state = it
        sum += x
    end
    return sum, count+n, state
end

# recursive case: given sum of ≈ n elements, return sum + (next n), and iteration state
# (or nothing if no remaining elements).
function _streaming_mean_pairwise(sum, count, itr, state, n)
    if n <= PAIRWISE_BASE # base case
        s, count, state = _sum_count_n(zero(sum), count, itr, state, n)
    else
        nhalf = n >> 1
        s, count, state = _streaming_mean_pairwise(zero(sum), count, itr, state, n - nhalf)
        if state !== nothing
            s, count, state = _streaming_mean_pairwise(s, count, itr, state, nhalf)
        end
    end
    # note: split return statement to improve type stability
    return isnothing(state) ? (sum + s, count, nothing) : (sum + s, count, state)
end
```

```julia
julia> @btime streaming_mean_pairwise($itr);
  117.636 ms (0 allocations: 0 bytes)
```

It seems to be very slightly less accurate than Kahan or the (very slow) `pairwisesum` above:

```julia
julia> streaming_mean_pairwise(itr) - exact
4.204033376975502e-8
```

I'm not sure it's worth it here: 4-chunked Kahan is a much simpler algorithm in comparison, has very decent performance, and should be at least as accurate as pairwise summation with a large base case.
---

In comparison, if we simply apply Kahan summation to chunks of 128 elements, the same as the pairwise base case, then it is both more accurate and faster than the pairwise algorithm:

```julia
function streaming_mean_kahan128(itr)
    it = iterate(itr)
    it === nothing && return zero(eltype(itr)) / 1
    s_, state = it
    T = typeof(s_ / 1)
    s::T = s_
    c = zero(s)
    count = 1
    while state !== nothing
        # naive summation of a chunk of 128 elements
        x, count, state = _sum_count_n(zero(s), count, itr, state, 128)
        # apply Kahan summation to this chunk of the sum
        y = x - c
        t = s + y
        c = (t - s) - y
        s = t
    end
    return s / count
end
```

gives

```julia
julia> streaming_mean_kahan128(itr) - exact
-1.7564311005635602e-8

julia> @btime streaming_mean_kahan128($itr)
  100.566 ms (0 allocations: 0 bytes)
```

This makes me think that we should also ditch the pairwise algorithm for arrays in favor of chunked Kahan summation.
---

I definitely agree that it'd be ideal if the same rough algorithm was used by iterators and arrays.

As for "chunked Kahan" style algorithms, I think there's an issue. I definitely could be wrong on this, so forgive me if I am. I think that it basically defeats the purpose of Kahan in the first place. Catastrophic errors are possible from adding just 3 (2?) numbers by naive means. So adding many numbers naively and then doing the compensation would have irrecoverably forfeited the protection Kahan usually provides. For example, it appears to me that the compensation in `streaming_mean_kahan4` cannot recover error committed inside the naive 4-element chunk sums.
---

Kahan cannot save the situation when there is less precision:
---

It's exactly the same principle as doing pairwise summation with a coarsened base case to improve performance. It's equivalent to summing each chunk of 128 elements naively, then doing compensated (Kahan) summation of the sequence of chunk sums.

Not only is this strictly better than naive summation of the whole array, it also tends to get the full benefit of Kahan in typical cases of large data sets. The reason is that, for large datasets, usually the overall sum is much larger than the sum of any chunk of 128 elements, so worsened roundoff errors from the naive chunk sums are insignificant (round to zero) relative to the overall sum. Of course, there are ways to defeat this: if the sum is dominated by a few elements that appear in a single chunk and have a large cancellation error within that chunk, for example. But the same cases will also defeat pairwise summation with a large base case. The point here is not to be perfect in all cases (for that you need something like Xsum.jl), but rather to avoid poor asymptotic scaling of the error for large n (n = length of data). Naive summation has O(sqrt(n)) rms error growth. Pairwise summation has O(sqrt(log(n))) growth, and Kahan has O(1) growth, and these growth rates are independent of the size of the base case (which may affect the constant factor, however).
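A minimal illustration of that equivalence (my own sketch, not code from the thread; the `kahan_sum` helper is hypothetical):

```julia
# Compensated (Kahan) summation of an arbitrary iterator of Float64s.
function kahan_sum(itr)
    s = 0.0
    c = 0.0
    for x in itr
        y = x - c
        t = s + y
        c = (t - s) - y
        s = t
    end
    return s
end

x = randn(10^6)
chunk_sums = (sum(chunk) for chunk in Iterators.partition(x, 128))  # naive inner sums
s = kahan_sum(chunk_sums)  # compensated outer sum: same scheme as 128-chunked Kahan over x
```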
---

It can get pretty bad even in double precision:
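(The example above is elided in this copy; as a stand-in, here is a construction of my own showing the overflow failure mode in double precision, which the reply below addresses:)

```julia
julia> itr = (1e305 for _ in 1:10^4);  # true mean is exactly 1e305

julia> sum(itr) / 10^4  # the running sum exceeds floatmax(Float64) ≈ 1.8e308 partway through
Inf
```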
---

Overflow is a separate issue. To avoid that you need to scale intermediate results. The situation could be improved without too big of a cost tradeoff by, e.g., pairwise summation where you compute the mean in chunks rather than the sum in chunks.
---

No, you are making the same mistake I did: you need to generate the same random numbers each time. If you do it correctly:

```julia
using Xsum
B = rand(Float64, 100_000_000);
Bitr = Iterators.take(B, length(B))
@show exact = xsum(B) / length(B)
@show mean(B) - exact
@show mean(Bitr) - exact
@show streaming_mean(Bitr) - exact
@show streaming_mean_kahan(Bitr) - exact
@show streaming_mean_kahan4(Bitr) - exact
@show streaming_mean_kahan128(Bitr) - exact
@show streaming_mean_pairwise(Bitr) - exact
```

you get very good results from the Kahan (chunked or not) or pairwise methods:

```julia
exact = xsum(B) / length(B) = 0.4999865615735586
mean(B) - exact = 0.0
mean(Bitr) - exact = -3.2029934260435766e-14
streaming_mean(Bitr) - exact = 1.708633234898116e-13
streaming_mean_kahan(Bitr) - exact = 0.0
streaming_mean_kahan4(Bitr) - exact = 1.1102230246251565e-16
streaming_mean_kahan128(Bitr) - exact = 0.0
streaming_mean_pairwise(Bitr) - exact = 0.0
```
---

Here is a pairwise variant that computes the mean in chunks, rather than the sum in chunks, so that it avoids most spurious overflow for small types like `Float16`:

```julia
# uses PAIRWISE_BASE and _sum_count_n definitions from previous post
function streaming_mean_pairwise2(itr)
    it = iterate(itr)
    it === nothing && return zero(eltype(itr)) / 1
    x, state = it
    n = PAIRWISE_BASE
    sum, count, state = _sum_count_n(x / 1, 1, itr, state, n - 1)
    mean = sum / count
    while state !== nothing
        mean, count, state = _streaming_mean_pairwise2(mean, count, itr, state, n)
        n *= 2
    end
    return mean
end

# work around julia issue #52394
Base.Float16(x::Rational{<:Union{Int16,Int32,Int64,Int128,UInt16,UInt32,UInt64,UInt128}}) = Float16(Float32(x))

function _streaming_mean_pairwise2(mean, count, itr, state, n)
    if n <= PAIRWISE_BASE # base case
        s, c, state = _sum_count_n(zero(mean), 0, itr, state, n)
        iszero(c) && return mean, count, nothing
        m = s / c
    else
        nhalf = n >> 1
        m, c, state = _streaming_mean_pairwise2(zero(mean), 0, itr, state, n - nhalf)
        iszero(c) && return mean, count, nothing
        if state !== nothing
            m, c, state = _streaming_mean_pairwise2(m, c, itr, state, nhalf)
        end
    end
    newcount = count + c
    # arranged carefully to reduce spurious overflow, using rationals to avoid promoting to Float64:
    mean += (m - mean) * (c // newcount)
    return isnothing(state) ? (mean, newcount, nothing) : (mean, newcount, state)
end
```

Test:

```julia
A16 = rand(Float16, 100_000_000);
itr16 = Iterators.take(A16, length(A16));
@show exact16 = sum(Float64, A16) / length(A16);
@show mean(A16) - exact16;
@show mean(itr16) - exact16;
@show streaming_mean(itr16) - exact16;
@show streaming_mean_kahan(itr16) - exact16;
@show streaming_mean_pairwise(itr16) - exact16;
@show streaming_mean_pairwise2(itr16) - exact16;
```

which gives:

```julia
exact16 = sum(Float64, A16) / length(A16) = 0.49976733576660154
mean(A16) - exact16 = NaN
mean(itr16) - exact16 = -0.49976733576660154
streaming_mean(itr16) - exact16 = 0.00902172673339846
streaming_mean_kahan(itr16) - exact16 = NaN
streaming_mean_pairwise(itr16) - exact16 = NaN
streaming_mean_pairwise2(itr16) - exact16 = -0.00025561701660153924
```

(Of course, if we know that the element type is a small float like `Float16`, we could instead just accumulate in a wider type.) For `Float32`, the performance:

```julia
A32 = rand(Float32, 100_000_000);
itr32 = Iterators.take(A32, length(A32))
@btime mean($A32);
@btime mean($itr32);
@btime streaming_mean($itr32);
@btime streaming_mean_kahan($itr32);
@btime streaming_mean_kahan4($itr32);
@btime streaming_mean_kahan128($itr32);
@btime streaming_mean_pairwise($itr32);
@btime streaming_mean_pairwise2($itr32);
```

gives
I'm not sure of a good way to avoid overflow in Kahan or blocked Kahan; it really seems to need recursive rescaling.
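On the `Rational` trick in `_streaming_mean_pairwise2` (`c // newcount`): a small demonstration of my own (not from the thread) of why a `Rational` scale factor avoids promoting small floats to `Float64`:

```julia
julia> Float32(1.0) * (1 // 3)   # Rational scale keeps the Float32 type
0.33333334f0

julia> Float32(1.0) * (1 / 3)    # a Float64 ratio would promote the result
0.3333333333333333
```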
---

I'm not sure why; I thought you meant because of division cost. I tried to look up the latency of division (and e.g. reciprocal), but it's hard to find (and likely differs across CPUs); at least I know the 32-bit assembly division instruction is faster than the 64-bit one. I've exploited that in a benchmark, combining divisions (I see I independently discovered the trick documented here by Agner): https://www.agner.org/optimize/optimizing_cpp.pdf

You can do a (32-bit) reciprocal with vdivps, and a "low accuracy" (presumably faster) one with "vrcp14ps + NR" (the latter seems not to be available for 64-bit, except on the discontinued Knights Landing, maybe with the discontinued AVX-512ER). Julia actually does neither for me on my computer; it uses vdivss. At least on AMD:

https://colfaxresearch.com/skl-avx512/

Intriguing news for the "Scalable family (formerly Skylake)": [I think Knights Landing is discontinued.]

https://stackoverflow.com/questions/4125033/floating-point-division-vs-floating-point-multiplication

Found while looking this up; it seems to have info on those "arith.divider_active" hardware counters:
---

See #52397.
---

@nalimilan, yes, with #52397 we could in principle re-implement `mean` for arbitrary iterators.
---

Is there a reason that we can't have a more numerically stable streaming `mean` method for arbitrary iterators? Here's a sketch of an implementation, and a demonstration of the numerical performance:
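(The original post's code blocks did not survive extraction. Below is a reconstruction, an assumption pieced together from details quoted elsewhere in the thread: `Iterators.peel`, `enumerate`, and one division per iteration, along with a demo in the spirit of the examples above.)

```julia
# Hypothetical reconstruction of the OP's sketch, not the original code:
# maintain a running mean, updating it with each new element.
function streaming_mean(itr)
    a, rest = Iterators.peel(itr)  # does not handle the empty-iterator case
    for (i, x) in enumerate(rest)
        a += (x - a) / (i + 1)  # one division per iteration
    end
    return a
end

# Demonstration: sequential Float32 accumulation in `mean` stalls once the
# partial sum reaches ~2^24, while the running mean stays near the true value.
using Statistics
A = rand(Float32, 10^8)
mean(x for x in A)            # roughly 2^24 / 10^8 ≈ 0.168 instead of ≈ 0.5
streaming_mean(x for x in A)  # ≈ 0.5, good to ~1e-5 (cf. the error table above)
```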
(I would imagine that there are some tweaks to be made to optimize the performance of `streaming_mean`, but I hope by now we've learned that correctness (i.e. numerical accuracy) is more important than performance.)