More accurate streaming mean method for arbitrary iterators #52365

Open
CameronBieganek opened this issue Dec 1, 2023 · 25 comments

Comments

@CameronBieganek
Contributor

CameronBieganek commented Dec 1, 2023

Is there a reason that we can't have a more numerically stable streaming mean method for arbitrary iterators? Here's a sketch of an implementation:

function streaming_mean(itr)
    a, rest = Iterators.peel(itr)
    for (i, x) in enumerate(rest)
        a += (x - a) / (i + 1)
    end
    return a
end
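The update line is the standard running-mean recurrence. Writing $a_k$ for the mean of the first $k$ elements (notation introduced here, not in the original post), the loop's `i + 1` plays the role of $k$:

$$
a_1 = x_1, \qquad a_k = \frac{1}{k}\sum_{j=1}^{k} x_j = \frac{(k-1)\,a_{k-1} + x_k}{k} = a_{k-1} + \frac{x_k - a_{k-1}}{k}.
$$

In exact arithmetic this equals the sum divided by the count. In floating point every step rounds, so its error behaves differently from sum-then-divide.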

And here's a demonstration of the numerical performance:

julia> using Statistics

julia> A = rand(Float32, 100_000_000);

julia> itr = (rand(Float32) for _ in 1:100_000_000);

julia> mean(A), mean(itr), streaming_mean(itr)
(0.4999884f0, 0.16777216f0, 0.49996638f0)

(I would imagine that there are some tweaks to be made to optimize the performance of streaming_mean, but I hope by now we've learned that correctness (i.e. numerical accuracy) is more important than performance.)

@mikmoore
Contributor

mikmoore commented Dec 1, 2023

This problem is not unique to Statistics.mean. Here's a simpler and more widespread issue:

julia> sum(1f0 for _ in 1:1e8)
1.6777216f7

That said, there are pathological inputs to most (all?) uncompensated sequential accumulations. It's worth ensuring that an algorithm of this flavor doesn't simply introduce different pathologies of equal concern.

One unfortunate, although (arguably) not devastating, consequence of an algorithm like this one is the propensity for results like this:

julia> streaming_mean([-1,-1,-1,1,1,1])
-2.7755575615628914e-17
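For comparison (a sketch, not from the original comment): summing first and dividing once returns exactly zero for this input, because the integer sum is exact and only the final division rounds. The recurrence instead rounds at several intermediate steps (e.g. 1.5 / 5 and 1.2 / 6 are not exact in binary), leaving a tiny residue.

```julia
# Running-mean recurrence from the original post.
function streaming_mean(itr)
    a, rest = Iterators.peel(itr)
    for (i, x) in enumerate(rest)
        a += (x - a) / (i + 1)
    end
    return a
end

v = [-1, -1, -1, 1, 1, 1]
naive = sum(v) / length(v)    # Int sum is exactly 0; one division gives 0.0
streamed = streaming_mean(v)  # intermediate roundings leave a tiny nonzero value
```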

@PallHaraldsson
Contributor

PallHaraldsson commented Dec 1, 2023

A. It's intriguing, and confusing, why your version is not deterministic:

julia> @time streaming_mean(itr)
  0.769110 seconds (1 allocation: 16 bytes)
0.5000117f0

julia> @time streaming_mean(itr)
  0.699333 seconds (1 allocation: 16 bytes)
0.4999693f0

My versions are also nondeterministic:
julia> @time streaming_mean3(itr)
  0.369009 seconds (1 allocation: 16 bytes)
0.50004f0


julia> @time streaming_mean2(itr)
  0.666329 seconds (1 allocation: 16 bytes)
0.5000293f0

julia> @time streaming_mean2(itr)
  0.707315 seconds (1 allocation: 16 bytes)
0.49995688f0

julia> function streaming_mean2(itr)
           a, rest = Iterators.peel(itr)
           for (i, x) in enumerate(rest)
               a += (x - a) / Float32(i + 1)
           end
           return a
       end

The 3rd version, which is much faster (but is it OK?), uses a += (x - a) * inv(Float32(i + 1)).
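The third variant is only described by its update line. A reconstruction under that description (everything except the update line is assumed to match streaming_mean2) multiplies by a Float32 reciprocal instead of dividing. Note that x * inv(y) rounds twice (once for the reciprocal, once for the product), so it can differ from x / y in the last bit.

```julia
# Hypothetical reconstruction of streaming_mean3 from its one-line description:
# same running-mean recurrence, but multiplying by a Float32 reciprocal.
function streaming_mean3(itr)
    a, rest = Iterators.peel(itr)
    for (i, x) in enumerate(rest)
        # inv(Float32(i + 1)) keeps the arithmetic in Float32; the reciprocal
        # is rounded, and then the product is rounded again.
        a += (x - a) * inv(Float32(i + 1))
    end
    return a
end

streaming_mean3(Float32[1, 2, 3, 4])   # ≈ 2.5f0
```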

I can't add @simd; is there a workaround?
ERROR: LoadError: Base.SimdLoop.SimdError("simd loop index must be a symbol")

julia> @time mean(A)  # apparently always deterministic:
  0.049539 seconds (1 allocation: 16 bytes)
0.5000132f0

B. FYI:

julia> sum([1f0 for _ in 1:1e8])  # correct, likely your point
1.0f8

I'm not sure exactly why your example gives 1.6777216f7 (do you know?), except that it's likely because it uses an iterator, which was presumably your point: iterators are broken for this. At first I thought you meant to write what I wrote, and since the unbracketed form is legal in that context (inside sum), I assumed it should be legal here too:

julia> 1f0 for _ in 1:10
┌ Error: JuliaSyntax parser failed — falling back to flisp!
│   exception =
│    BoundsError: attempt to access 0-element Vector{Any} at index [1]

That's not ideal (on 1.10rc1), whereas on 1.9.4:

julia> 1f0 for _ in 1:10
ERROR: syntax: extra token "for" after end of expression

[if this should actually ERROR, then why not when wrapped in a sum?]

@mikmoore
Contributor

mikmoore commented Dec 1, 2023

Parentheses are necessary to construct generators in some contexts.

julia> 1f0 for _ in 1:100 # invalid syntax
ERROR: syntax: extra token "for" after end of expression
Stacktrace:
 [1] top-level scope
   @ none:1

julia> (1f0 for _ in 1:100) # generator
Base.Generator{UnitRange{Int64}, var"#104#105"}(var"#104#105"(), 1:100)

julia> [1f0 for _ in 1:100] |> typeof # comprehension
Vector{Float32} (alias for Array{Float32, 1})

The original example is non-deterministic because it is a generator that calls rand to produce values. Adding brackets makes it a comprehension, which is eagerly evaluated and so will produce a consistent value because the random values are only sampled once.

The discrepancy in the results of sum comes from the type of iterable. sum uses pairwise summation over AbstractArray arguments, but sequential summation over general iterators. In this case, the sequential sum eventually saturates due to limited precision.

julia> maxintfloat(Float32) # adding 1 to this value has no effect
1.6777216f7

julia> sum(1f0 for _ in 1:1e8) # sequential sum
1.6777216f7

julia> sum([1f0 for _ in 1:1e8]) # pairwise sum
1.0f8
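To make the distinction concrete, here is a minimal recursive pairwise sum over an AbstractVector (a sketch; the name pairwise_sum and the base-case size 8 are hypothetical, and Base's real implementation uses a much larger base case plus SIMD). Float16 is used so that sequential summation stalls after only 2048 terms:

```julia
# Hypothetical minimal pairwise (cascade) summation over a vector.
function pairwise_sum(v::AbstractVector, lo::Int = firstindex(v), hi::Int = lastindex(v))
    n = hi - lo + 1
    if n <= 8                    # small base case: plain sequential loop
        s = zero(eltype(v))
        for i in lo:hi
            s += v[i]
        end
        return s
    end
    mid = lo + n ÷ 2 - 1         # split into two halves and recurse
    return pairwise_sum(v, lo, mid) + pairwise_sum(v, mid + 1, hi)
end

v = fill(Float16(1), 10_000)
foldl(+, v)        # sequential: stalls at maxintfloat(Float16) == 2048
pairwise_sum(v)    # pairwise: every partial sum here is exactly representable
```

Each element only passes through about log2(n) additions on its way into the total, instead of up to n - 1, which is why the partial sums stay small relative to what is added to them.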

There isn't anything intrinsically precluding a pairwise summation algorithm over general iterators, although it's a bit more complicated to implement. That could be used to improve the sum algorithm, although it'd be an outstanding question whether we wanted to use the same for mean.

For what it's worth, the streaming_mean algorithm suffers more or less the same issue as sequential summation. Eventually, further accumulations will be no-ops because abs(x-a) / (i + 1) drops below half the floating-point spacing at a (since x-a is bounded for bounded inputs and i + 1 grows indefinitely), so a + (x-a)/(i+1) rounds back to a. For example:

julia> streaming_mean(Iterators.flatten((Iterators.repeated(0f0, 10^8),Iterators.repeated(1f0, 10^8))))
0.13370937f0 # not 0.5f0
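A single step shows the stall (the numbers are hand-picked for illustration, not taken from the example above): once the increment is below half the gap between a and the next larger Float32, the update rounds back to a.

```julia
a = 0.25f0           # current running mean
x = 1f0              # next element, far from the mean
n = 2^26             # plays the role of i + 1
incr = (x - a) / n   # Float32: 0.75 * 2^-26 ≈ 1.1e-8
half_gap = eps(a) / 2  # half the gap above 0.25f0: 2^-26 ≈ 1.5e-8
a + incr == a        # true: the update is a no-op
```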

@CameronBieganek
Contributor Author

Interesting. I don't know much about floating point arithmetic. I would just hope that there would be some way that we could improve the numerical accuracy situation for mean(itr) and sum(itr). 🙂

@PallHaraldsson
Copy link
Contributor

sum uses pairwise summation over AbstractArray arguments, but sequential summation over general iterators.

I know the difference, sort of, i.e. downsides of the latter. Is the former always better, and can't it just also be used for generators?

The absolute (and relative) error seems large, i.e. 100000000/16777216 ≈ 5.96×, and I know you chose the example to highlight that, but it seems to be a consequence of floating point, and both results can arguably be defined as "right". It might not be that common to sum that many numbers (whose running sum keeps growing in magnitude rather than converging). At least note that the pairwise method isn't immune either:

julia> sum([100f0 for _ in 1:1e9])
1.0000001f11

julia> sum([1000f0 for _ in 1:1e9])
9.999999f11

julia> sum([1f0 for _ in 1:1e10])  # trying to sum even more values got killed, I think from running out of memory (OOM):
Killed

People should be informed and use Float64 if worried (or should a Float32 sum accumulate in Float64 and convert back to Float32 at the end? That's not bad, except on GPUs, where Float64 may be unavailable or slower). You end up with the same problem with Float64, even Float128 and BigFloat... at least if you sum sufficiently large and/or many values.

is non-deterministic because it is a generator that calls rand to produce values.

That's good to know, and plausible, though I do not see it documented at:

help?> Base.Generator

nor do I see a need, good reason, for rand, but if I'm looking at the right place shouldn't it be documented (if/when user-visible consequences). Also, since we're just adding up the same values (1s), it seems the order shouldn't matter. Why is streaming_mean (or what it uses) nondeterministic? Presumably because of enumerate, though its docs don't say so (nor that it uses a generator; I suppose it's not because of Iterators.peel, whose docs don't indicate it either):

An iterator that yields (i, x) where i is a counter starting at 1, and x is the ith value from the given iterator.

That seems to imply a for loop and no rand needed...

@CameronBieganek
Contributor Author

is non-deterministic because it is a generator that calls rand to produce values.

That's good to know, and plausible, though I do not see it documented at:

help?> Base.Generator

nor do I see a need, good reason, for rand, but if I'm looking at the right place shouldn't it be documented (if/when user-visible consequences).

It's nothing in particular about generators, it's just that I put rand() inside the generator, so every time you iterate itr, it generates a new sequence of random numbers. It's not that much different from repeatedly calling rand(100): you get a different sequence of random numbers every time.
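A deterministic way to see this (an illustration, not from the thread; the counter and the helper bump! are hypothetical): a generator re-runs its body every time it is iterated, so anything with side effects, such as rand() or the counter below, is evaluated afresh on each pass. Lazily wrapping an already-materialized array, e.g. with Iterators.take(A, length(A)), yields the same values every time.

```julia
counter = Ref(0)
bump!(c) = (c[] += 1)              # side effect: increments and returns the counter
g = (bump!(counter) for _ in 1:3)  # the body runs lazily, once per element, on every pass

first_pass = collect(g)    # [1, 2, 3]
second_pass = collect(g)   # [4, 5, 6]: the body ran again

A = [10, 20, 30]
t = Iterators.take(A, length(A))   # lazily walks a fixed array
collect(t) == collect(t)           # true: same values on every pass
```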

@stevengj changed the title from "Numerically stable streaming mean method for arbitrary iterators" to "more accurate streaming mean method for arbitrary iterators" on Dec 2, 2023
@stevengj
Member

stevengj commented Dec 2, 2023

I've changed the title to "more accurate" rather than "numerically stable". According to the mathematical definition, the current algorithm (naive in-order summation) is numerically stable: naive summation is backwards stable. (The Wikipedia article on this subject is truly terrible, by the way; it never gives a precise definition.) The fact that the error grows with n faster than you would like is not the same thing as instability.
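For reference, the textbook backward-error result for recursive (in-order) summation with unit roundoff $u$ (see e.g. Higham, Accuracy and Stability of Numerical Algorithms) says the computed sum is the exact sum of slightly perturbed inputs:

$$
\widehat{S}_n = \operatorname{fl}\Big(\sum_{i=1}^{n} x_i\Big) = \sum_{i=1}^{n} x_i\,(1+\delta_i), \qquad |\delta_i| \le \gamma_{n-1} = \frac{(n-1)\,u}{1-(n-1)\,u}.
$$

Each input is perturbed by a relative amount of at most about $(n-1)u$, which is what "backward stable" means here, even though the resulting forward error bound $|\widehat{S}_n - S_n| \le \gamma_{n-1}\sum_i |x_i|$ grows with $n$.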

@CameronBieganek changed the title from "more accurate streaming mean method for arbitrary iterators" to "More accurate streaming mean method for arbitrary iterators" on Dec 2, 2023
@PetrKryslUCSD

I've changed the title to "more accurate" rather than "numerically stable". According to the mathematical definition, the current algorithm (naive in-order summation) is numerically stable: naive summation is backwards stable. (The Wikipedia article on this subject is truly terrible, by the way; it never gives a precise definition.) The fact that the error grows with n faster than you would like is not the same thing as instability.

It may be true that the current algorithm in mean is numerically stable, but the proposed algorithm certainly appears more robust. Is there any reason for not using that algorithm?

@nalimilan
Member

I'm not sure of the relative advantages of each approach in terms of accuracy, but at least in terms of performance the advantage of the current one is clear. Taking the example from the OP:

julia> using Statistics, BenchmarkTools

julia> A = rand(Float32, 100_000_000);

julia> itr = (x for x in A);

julia> @btime mean(A);
  23.134 ms (1 allocation: 16 bytes)

julia> @btime streaming_mean(A);
  463.581 ms (1 allocation: 16 bytes)

julia> @btime mean(itr);
  106.404 ms (1 allocation: 16 bytes)

julia> @btime streaming_mean(itr);
  474.276 ms (1 allocation: 16 bytes)

Of course, returning a wrong solution quickly isn't useful, but we have to make sure the proposed algorithm is strictly and significantly better than the existing one before switching.

Also, as noted at JuliaStats/Statistics.jl#140 (comment), for small float types, a good solution would be to accumulate in Float64, which is both accurate and fast:

julia> itr2 = (Float64(x) for x in A);

julia> @btime mean(itr2);
  106.298 ms (1 allocation: 16 bytes)

(The difficulty is how to handle heterogeneous arrays, as discussed in that issue.)

@CameronBieganek
Contributor Author

we have to make sure the proposed algorithm is strictly and significantly better than the existing one before switching

I'm no expert on floating point arithmetic, but based on the tradeoffs described by @mikmoore above, it seems like "strictly better" (in terms of numerical accuracy) is too high a bar. Perhaps "more accurate in the large majority of likely use cases" would be a better decision rule. Of course I imagine it's harder to assess whether that rule has been satisfied.

@stevengj
Member

stevengj commented Dec 4, 2023

The proposed algorithm (discussed in this stackoverflow post) does one division per iteration. If you're willing to pay that price, you might as well do Kahan summation, which does extra additions but no division.

For example:

function streaming_mean_kahan(itr)
    s, rest = Iterators.peel(itr)
    c = zero(s)
    count = 1
    for x in rest
        y = x - c
        t = s + y
        c = (t - s) - y
        s = t
        count += 1
    end
    return s / count
end

(which doesn't handle the empty-iterator case, like the original post).
On the test above, this gives

julia> streaming_mean_kahan(itr)
0.4999869f0

Even without optimizing it, it's about 20% faster than the division-based algorithm on my machine:

julia> @btime mean($A); @btime mean($itr); @btime streaming_mean($itr); @btime streaming_mean_kahan($itr);
  20.456 ms (0 allocations: 0 bytes)
  150.730 ms (0 allocations: 0 bytes)
  479.155 ms (0 allocations: 0 bytes)
  402.276 ms (0 allocations: 0 bytes)

(Both algorithms can easily be further optimized by summing multiple sub-streams at once with SIMD.)
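A tiny deterministic case (constructed for illustration; the helper kahan_sum is hypothetical but uses the same update as streaming_mean_kahan) shows what the compensation term buys. Adding 1f-8 to 1f0 is below half an ulp of 1f0, so naive summation never moves, while Kahan's c accumulates the lost low-order part until it is large enough to register.

```julia
# Plain Kahan (compensated) summation, same update as streaming_mean_kahan above.
function kahan_sum(itr)
    s, rest = Iterators.peel(itr)
    c = zero(s)              # running compensation: the (negated) lost low-order part
    for x in rest
        y = x - c            # re-inject what was lost on the previous step
        t = s + y            # big + small: low-order bits of y may be lost
        c = (t - s) - y      # recover what was lost, with the sign flipped
        s = t
    end
    return s
end

v = [1f0; fill(1f-8, 10_000)]   # exact sum ≈ 1.0001
foldl(+, v)                      # 1.0f0: every 1f-8 is rounded away
kahan_sum(v)                     # ≈ 1.0001f0
```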

@stevengj
Member

stevengj commented Dec 4, 2023

A simple thing that can increase performance without losing much accuracy is to do Kahan summation "chunked" to every 4 numbers at a time. That is, we sum 4 consecutive numbers naively, then do a Kahan step, then another 4, and so on:

function streaming_mean_kahan4(itr)
    it = iterate(itr)
    it === nothing && return zero(eltype(itr)) / 1
    s_, state = it
    T = typeof(s_ / 1) # accumulate in the precision of the result
    s::T = s_
    c = zero(s)
    count = 1  
    while true
        it = iterate(itr, state)
        it === nothing && return s / count
        x1::T, state = it
        it = iterate(itr, state)
        it === nothing && return (s+x1) / (count+1)
        x2::T, state = it
        it = iterate(itr, state)
        it === nothing && return (s+x1+x2) / (count+2)
        x3::T, state = it
        it = iterate(itr, state)
        it === nothing && return (s+x1+x2+x3) / (count+3)
        x4::T, state = it
        
        # apply Kahan summation chunked 4 elements at a time
        @fastmath x = x1+x2+x3+x4 # (re-associating this line is allowed)
        y = x - c
        t = s + y
        c = (t - s) - y
        s = t
        count += 4
    end
end

On my machine, this is actually slightly faster than our current mean(itr):

julia> @btime streaming_mean_kahan4($itr);
126.799 ms (0 allocations: 0 bytes)

(vs. 150.730 ms for mean(itr), above), while still attaining excellent accuracy:

julia> streaming_mean_kahan4(itr)
0.50002027f0

Here, compared to the exact result, the most accurate method on this test data (on my machine) is the pairwise mean(A), followed by streaming_mean_kahan(itr), then streaming_mean_kahan4(itr), then the proposed streaming_mean(itr). (Correction: see below.)

@mikmoore
Contributor

mikmoore commented Dec 4, 2023

There's no reason we can't use pairwise reduction on an iterator. As a proof-of-concept (with bad performance and a hideous implementation), see the below. Someone better-versed in iterators should be able to make something more performant and legible. They should probably also use a larger base-case like Array reduction already does.

function pairwisesum(itr)
	valstate = iterate(itr)
	isnothing(valstate) && error("empty iterator but we need to furnish the correct error message")
	val,state = valstate
	s = val
	lognumtosum = 0
	while true
		s1state = _pairwisesum(itr,state,lognumtosum) # sum of the next 2^lognumtosum elements, or fewer if we run out early
		isnothing(s1state) && return s # iterator was already exhausted
		s1,state = s1state::Tuple
		s += s1 # sum of the first 2 * 2^lognumtosum elements, or fewer if we run out early
		isnothing(state) && return s # iterator is exhausted
		lognumtosum += 1
	end
end

function _pairwisesum(itr,state,lognumtosum)
	if lognumtosum < 1
		valstate = iterate(itr,state)
		isnothing(valstate) && return nothing
		val,state = valstate::Tuple
		return val,state
	else
		s1state = _pairwisesum(itr,state,lognumtosum-true)
		isnothing(s1state) && return nothing # failed to get any values
		s1,state = s1state::Tuple
		isnothing(state) && return s1,nothing # got values but iterator is now exhausted
		s2state = _pairwisesum(itr,state,lognumtosum-true)
		isnothing(s2state) && return s1,nothing # got no further values from second call
		s2,state = s2state::Tuple
		s12 = s1 + s2
		# NOTE: we split the following return statements to improve type stability
		#   Without the explicit isnothing check, return Tuple{...,Union{...,Nothing}}
		#   With the explicit isnothing check, return Union{Tuple{...,...},Tuple{...,Nothing}}
		isnothing(state) && return s12,nothing
		return s12,state # combine values and return
	end
end

@stevengj
Member

stevengj commented Dec 4, 2023

Correction — in order to assess the accuracy, I have to change the iterator to

itr = Iterators.take(A, length(A))

so that it is using the same random numbers every time. Then

exact = sum(Float64, A) / length(A)
@show mean(A) - exact
@show streaming_mean(itr) - exact
@show streaming_mean_kahan(itr) - exact
@show streaming_mean_kahan4(itr) - exact
@show pairwisesum(itr) / length(A) - exact

gives:

mean(A) - exact = -1.7564311005635602e-8
streaming_mean(itr) - exact = -1.0656993403412862e-5
streaming_mean_kahan(itr) - exact = -1.7564311005635602e-8
streaming_mean_kahan4(itr) - exact = -1.7564311005635602e-8
pairwisesum(itr) / length(A) - exact = -1.7564311005635602e-8

i.e. the pairwise mean, 4-element Kahan, Kahan, and @mikmoore's pairwise sum all give the same results, whereas @CameronBieganek's streaming_mean is about 600x (nearly three orders of magnitude) worse.

@stevengj
Member

stevengj commented Dec 4, 2023

I tried to implement a single-pass pairwise streaming mean that has an enlarged/coarsened base case (switching to iterative mean at 128 elements):

const PAIRWISE_BASE = 128

function streaming_mean_pairwise(itr)
    it = iterate(itr)
    it === nothing && return zero(eltype(itr)) / 1
    x, state = it
    n = PAIRWISE_BASE
    sum, count, state = _sum_count_n(x / 1, 1, itr, state, n - 1)
    while state !== nothing
        sum, count, state = _streaming_mean_pairwise(sum, count, itr, state, n)
        n *= 2
    end
    return sum / count
end

# base case: accumulate sum and count for at most n elements of itr,
# starting from given iteration state, returning accumulated sum, count, state
# where state is final iteration state, or nothing if end of iterator encountered
function _sum_count_n(sum, count, itr, state, n)
    for i = 1:n
        it = iterate(itr, state)
        it === nothing && return sum, count+i-1, nothing
        x, state = it
        sum += x
    end
    return sum, count+n, state
end

# recursive case: given sum of ≈ n elements, return sum + (next n), and iteration state
# (or nothing if no remaining elements).
function _streaming_mean_pairwise(sum, count, itr, state, n)
    if n <= PAIRWISE_BASE # base case
        s, count, state = _sum_count_n(zero(sum), count, itr, state, n)
    else
        nhalf = n >> 1
        s, count, state = _streaming_mean_pairwise(zero(sum), count, itr, state, n - nhalf)
        if state !== nothing
            s, count, state = _streaming_mean_pairwise(s, count, itr, state, nhalf)
        end
    end
    # note: split return statement to improve type stability
    return isnothing(state) ? (sum + s, count, nothing) : (sum + s, count, state)
end

Initially the performance was not good due to compiler issues: there were zillions of allocations. There is the usual two-fold type instability of iterators (e.g. state is either a tuple or nothing), but apparently Julia was not optimizing that away, maybe because it crosses function boundaries. After splitting the final return statement, performance is now very good:

julia> @btime streaming_mean_pairwise($itr);
117.636 ms (0 allocations: 0 bytes)

It seems to be very slightly less accurate than Kahan or the (very slow) pairwisesum (with base case 1) above on this particular dataset:

julia> streaming_mean_pairwise(itr) - exact
4.204033376975502e-8

I'm not sure it's worth it here: 4-chunked Kahan is a much simpler algorithm in comparison, has very decent performance, and should be at least as accurate as pairwise summation with a large base case.

@stevengj
Member

stevengj commented Dec 4, 2023

In comparison, if we simply apply Kahan summation to chunks of 128 elements, the same as the pairwise base case, then it is both more accurate and faster than the pairwise algorithm:

function streaming_mean_kahan128(itr)
    it = iterate(itr)
    it === nothing && return zero(eltype(itr)) / 1
    s_, state = it
    T = typeof(s_ / 1)
    s::T = s_
    c = zero(s)
    count = 1  
    while state !== nothing
        # naive summation of a chunk of 128 elements
        x, count, state = _sum_count_n(zero(s), count, itr, state, 128)
        
        # apply Kahan summation to this chunk of the sum
        y = x - c
        t = s + y
        c = (t - s) - y
        s = t
    end
    return s / count
end

gives

julia> streaming_mean_kahan128(itr) - exact
-1.7564311005635602e-8

julia> @btime streaming_mean_kahan128($itr)
  100.566 ms (0 allocations: 0 bytes)

This makes me think that we should also ditch the pairwise algorithm for sum in favor of chunked Kahan.

@mikmoore
Contributor

mikmoore commented Dec 4, 2023

I definitely agree that it'd be ideal if the same rough algorithm was used by iterators and AbstractArray (at least for total-array reductions -- probably not for dimensional). That doesn't seem intrinsically impossible, although SIMD may motivate slight differences.

As for "chunked Kahan" style algorithms, I think there's an issue. I definitely could be wrong on this, so forgive me if I am. I think that it basically defeats the purpose of Kahan in the first place. Catastrophic errors are possible from adding just 3 (2?) numbers by naive means. So adding many numbers naively and then doing the compensation would irrecoverably forfeit the protection Kahan usually provides. For example, it appears to me that the compensation in streaming_mean_kahan128 is a no-op on inputs shorter than 129, so it can't be doing a valid flavor of Kahan summation.

@PetrKryslUCSD

Kahan summation cannot save the situation when there is less precision:

using Statistics

A = rand(Float16, 100_000_000);

itr = (rand(Float16) for _ in 1:100_000_000);

function streaming_mean_kahan(itr)
    s, rest = Iterators.peel(itr)
    c = zero(s)
    count = 1
    for x in rest
        y = x - c
        t = s + y
        c = (t - s) - y
        s = t
        count += 1
    end
    return s / count
end

function streaming_mean(itr)
    a, rest = Iterators.peel(itr)
    for (i, x) in enumerate(rest)
        a += (x - a) / (i + 1)
    end
    return a
end

@show mean(A), streaming_mean_kahan(A), streaming_mean(A)

gives

(mean(A), streaming_mean_kahan(A), streaming_mean(A)) = (NaN, NaN, 4.99756e-01)
(NaN, NaN, 4.99756e-01)

@stevengj
Member

stevengj commented Dec 4, 2023

is a no-op on inputs shorter than 129 so it can't be doing a valid flavor of Kahan summation.

It's exactly the same principle as doing pairwise summation with a coarsened base case to improve performance. It's equivalent to:

  1. Sum every "chunk" of 128 elements naively, to produce a new array of length length(itr) / 128.
  2. Sum the new array using Kahan summation
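The two steps can be written out literally as a two-pass sketch over an array (the function name chunked_kahan_sum and its default chunk size are illustrative; the streaming versions above fuse both steps into a single pass). Float16 ones are used because naive sequential summation saturates at 2048:

```julia
# Step 1: naive sums of consecutive chunks. Step 2: Kahan summation of those sums.
function chunked_kahan_sum(v::AbstractVector{T}, chunk::Int = 128) where {T}
    partials = [foldl(+, view(v, r); init = zero(T))            # naive chunk sums
                for r in Iterators.partition(eachindex(v), chunk)]
    s, c = zero(T), zero(T)                                     # Kahan over chunk sums
    for x in partials
        y = x - c
        t = s + y
        c = (t - s) - y
        s = t
    end
    return s
end

v = fill(Float16(1), 10_000)
foldl(+, v)                 # 2048: naive summation stalls
chunked_kahan_sum(v, 128)   # 10000: each chunk sum (at most 128) is exact
```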

Not only is this strictly better than naive summation of the whole array, it also tends to get the full benefit of Kahan in typical cases of large data sets. The reason is that, for large datasets, usually the overall sum is much larger than the sum of any chunk of 128 elements, so worsened roundoff errors from the naive chunk sums are insignificant (round to zero) relative to the overall sum.

Of course, there are ways to defeat this — if the sum is dominated by a few elements that appear in a single chunk and have a large cancellation error within that chunk, for example. But the same cases will also defeat pairwise summation with a large base case.

The point here is not to be perfect in all cases (for that you need something like Xsum.jl), but rather to avoid poor asymptotic scaling of the error for large n (n = length of data). Naive summation has O(sqrt(n)) rms error growth, pairwise summation has O(sqrt(log(n))) growth, and Kahan has O(1) growth; these growth rates are independent of the size of the base case (which may affect the constant factor, however).
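For reference, the corresponding standard worst-case bounds (unit roundoff $u$, higher-order terms dropped except where they matter) have the same ordering as the rms rates:

$$
\begin{aligned}
\text{naive (in-order):}\quad & |E| \lesssim (n-1)\,u \sum_i |x_i| \\
\text{pairwise:}\quad & |E| \lesssim \lceil \log_2 n \rceil\, u \sum_i |x_i| \\
\text{Kahan:}\quad & |E| \le \big(2u + O(n u^2)\big) \sum_i |x_i|
\end{aligned}
$$

A naive base case of size $b$ adds roughly $b\,u\sum_i |x_i|$ to the pairwise bound without changing its $\log n$ growth.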

@PetrKryslUCSD

PetrKryslUCSD commented Dec 4, 2023

It can get pretty bad even in double precision:

A = (rand(Float64) for _ in 1:100_000_000);
@show exact = sum(Float64, A) / length(A)
@show mean(A) - exact
@show streaming_mean(A) - exact
@show streaming_mean_kahan(A) - exact

gives

exact = sum(Float64, A) / length(A) = 4.99995e-01
mean(A) - exact = 3.70823e-05
streaming_mean(A) - exact = 6.57060e-06
streaming_mean_kahan(A) - exact = -1.89430e-05

@stevengj
Member

stevengj commented Dec 4, 2023

Kahan summation cannot save the situation when there is less precision:

Overflow is a separate issue. To avoid that you need to scale intermediate results. The situation could be improved without too big of a cost tradeoff by, e.g. pairwise summation where you compute the mean in chunks rather than the sum in chunks.
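The NaN from streaming_mean_kahan on Float16 data is consistent with this. With hand-picked values (not from the thread): once the running sum passes floatmax(Float16), the compensation term becomes Inf, and the next Kahan step computes Inf - Inf.

```julia
s = floatmax(Float16)     # 65504, the largest finite Float16
x = Float16(100)
y = x - zero(Float16)     # first Kahan step, with c == 0
t = s + y                 # 65604 rounds past the overflow threshold: Inf16
c = (t - s) - y           # Inf16 - 65504 - 100 == Inf16
y2 = Float16(1) - c       # next step: 1 - Inf16 == -Inf16
t2 = t + y2               # Inf16 + (-Inf16) == NaN16
```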

It can get pretty bad even in double precision:

No, you are making the same mistake I did — you need to generate the same random numbers each time. If you do it correctly:

using Xsum
B = rand(Float64, 100_000_000);
Bitr = Iterators.take(B, length(B))
@show exact = xsum(B) / length(B)
@show mean(B) - exact
@show mean(Bitr) - exact
@show streaming_mean(Bitr) - exact
@show streaming_mean_kahan(Bitr) - exact
@show streaming_mean_kahan4(Bitr) - exact
@show streaming_mean_kahan128(Bitr) - exact
@show streaming_mean_pairwise(Bitr) - exact

you get very good results from the Kahan (chunked or not) or pairwise methods:

exact = xsum(B) / length(B) = 0.4999865615735586
mean(B) - exact = 0.0
mean(Bitr) - exact = -3.2029934260435766e-14
streaming_mean(Bitr) - exact = 1.708633234898116e-13
streaming_mean_kahan(Bitr) - exact = 0.0
streaming_mean_kahan4(Bitr) - exact = 1.1102230246251565e-16
streaming_mean_kahan128(Bitr) - exact = 0.0
streaming_mean_pairwise(Bitr) - exact = 0.0

@stevengj
Member

stevengj commented Dec 4, 2023

Here is a pairwise variant that computes the mean in chunks, rather than the sum in chunks, so that it avoids most spurious overflow for Float16. (Again, it bottoms out at a naive sum for 128 elements, so this only helps with the scaling for large collections, not if a few individual elements overflow when summed.)

# uses PAIRWISE_BASE and _sum_count_n definitions from previous post

function streaming_mean_pairwise2(itr)
    it = iterate(itr)
    it === nothing && return zero(eltype(itr)) / 1
    x, state = it
    n = PAIRWISE_BASE
    sum, count, state = _sum_count_n(x / 1, 1, itr, state, n - 1)
    mean = sum / count
    while state !== nothing
        mean, count, state = _streaming_mean_pairwise2(mean, count, itr, state, n)
        n *= 2
    end
    return mean
end

# work around julia issue #52394
Base.Float16(x::Rational{<:Union{Int16,Int32,Int64,Int128,UInt16,UInt32,UInt64,UInt128}}) = Float16(Float32(x))

function _streaming_mean_pairwise2(mean, count, itr, state, n)
    if n <= PAIRWISE_BASE # base case
        s, c, state = _sum_count_n(zero(mean), 0, itr, state, n)
        iszero(c) && return mean, count, nothing
        m = s / c
    else
        nhalf = n >> 1
        m, c, state = _streaming_mean_pairwise2(zero(mean), 0, itr, state, n - nhalf)
        iszero(c) && return mean, count, nothing
        if state !== nothing
            m, c, state = _streaming_mean_pairwise2(m, c, itr, state, nhalf)
        end
    end
    newcount = count + c
    # arranged carefully to reduce spurious overflow, using rationals to avoid promoting to Float64:
    mean += (m - mean) * (c // newcount)
    return isnothing(state) ? (mean, newcount, nothing) : (mean, newcount, state)
end

Test:

A16 = rand(Float16, 100_000_000);
itr16 = Iterators.take(A16, length(A16));

@show exact16 = sum(Float64, A16) / length(A16);
@show mean(A16) - exact16;
@show mean(itr16) - exact16;
@show streaming_mean(itr16) - exact16;
@show streaming_mean_kahan(itr16) - exact16;
@show streaming_mean_pairwise(itr16) - exact16;
@show streaming_mean_pairwise2(itr16) - exact16;

which gives:

exact16 = sum(Float64, A16) / length(A16) = 0.49976733576660154
mean(A16) - exact16 = NaN
mean(itr16) - exact16 = -0.49976733576660154
streaming_mean(itr16) - exact16 = 0.00902172673339846
streaming_mean_kahan(itr16) - exact16 = NaN
streaming_mean_pairwise(itr16) - exact16 = NaN
streaming_mean_pairwise2(itr16) - exact16 = -0.00025561701660153924

(Of course, if we know that the eltype is Float16 then we are probably better off just promoting to Float32 or Float64, but it is difficult to detect that reliably. For example, we could be taking the mean of an iterator of SVector{3,Float16} or something similar with Float16 components buried inside it. Alternatively, we could force at least Float64 precision by initializing the accumulation with first_element / 1.0 rather than first_element / 1, but that would spoil exactness for Rational mean computations.)
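The type consequences of the two initializations are easy to check at the REPL (the values here are arbitrary examples): dividing by the integer 1 keeps Float16 and Rational inputs in their own type, while dividing by 1.0 promotes both to Float64.

```julia
x16 = Float16(0.5)
q = 1 // 3

typeof(x16 / 1)     # Float16: the input's precision is kept
typeof(x16 / 1.0)   # Float64: forced promotion
typeof(q / 1)       # Rational{Int}: exactness is kept
typeof(q / 1.0)     # Float64: exactness is lost
typeof(3 / 1)       # Float64: integers still become floats, as a mean should
```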

For Float32 data, we pay a 15% performance price for the divisions compared to streaming_mean_pairwise (but still much better than streaming_mean):

A32 = rand(Float32, 100_000_000);
itr32 = Iterators.take(A32, length(A32))

@btime mean($A32);
@btime mean($itr32);
@btime streaming_mean($itr32);
@btime streaming_mean_kahan($itr32);
@btime streaming_mean_kahan4($itr32);
@btime streaming_mean_kahan128($itr32);
@btime streaming_mean_pairwise($itr32);
@btime streaming_mean_pairwise2($itr32);

gives

  20.334 ms (0 allocations: 0 bytes)
  107.383 ms (0 allocations: 0 bytes)
  482.192 ms (0 allocations: 0 bytes)
  405.375 ms (0 allocations: 0 bytes)
  106.306 ms (0 allocations: 0 bytes)
  99.855 ms (0 allocations: 0 bytes)
  117.612 ms (0 allocations: 0 bytes)
  134.374 ms (0 allocations: 0 bytes)

I'm not sure of a good way to avoid overflow in Kahan or blocked Kahan; it really seems to need recursive rescaling.

@PallHaraldsson
Contributor

For Float32 data, we pay a 15% performance price for the divisions compared to streaming_mean_pairwise (but still much better than streaming_mean)

I'm not sure why. I thought you meant it's because m = s / c is slow (it might not be, but the following is useful to know for other code, including the OP's here). If that's because Float64 is used, then doing divisions in Float32 is faster, and Float32 division is also faster than Float16, even on some CPUs that have hardware support for it. For accuracy I thought promoting everything to Float64 and converting back at the end might be better, but if promoting only to Float32 is enough, that might be preferable.

I tried to look up the latency of division (and e.g. of reciprocal), but it's hard to find (and likely differs between CPUs); at least I know the 32-bit division instruction is faster than the 64-bit one. I've exploited that in a benchmark, along with combining divisions (which I see I independently discovered; the trick is documented here by Agner Fog):

https://www.agner.org/optimize/optimizing_cpp.pdf

Multiple divisions can be combined. For example:
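The quoted example is not reproduced here, but the general idea is to trade two divisions for one division plus extra multiplications, e.g. by rewriting a1/b1 + a2/b2 over a common denominator. A minimal sketch with hypothetical function names; the result can differ from the two-division form in the last bits, and b1 * b2 can overflow or underflow where the original would not:

```julia
# Two divisions.
two_div(a1, b1, a2, b2) = a1 / b1 + a2 / b2

# One division plus extra multiplications: equal in exact arithmetic.
one_div(a1, b1, a2, b2) = (a1 * b2 + a2 * b1) / (b1 * b2)

two_div(1.0, 3.0, 2.0, 5.0), one_div(1.0, 3.0, 2.0, 5.0)   # both ≈ 11/15
```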

You can compute a (32-bit) reciprocal with vdivps, or a "low accuracy" (presumably faster) one with "vrcp14ps + NR" (a Newton-Raphson refinement step); this seems not to be available for 64-bit, except on the discontinued Knights Landing, maybe via the also discontinued AVX-512ER. Julia actually does neither on my computer; it uses vdivss.

At least on AMD:
https://bluewaters.ncsa.illinois.edu/liferay-content/document-library/amd_1_24592_APM.pdf

VDIVSS and VDIVSD have no 256-bit form.

https://colfaxresearch.com/skl-avx512/

For Skylake, FMA has a latency of 4 cycles and reciprocal throughput of 0.5 cycle (i.e., 2 instructions per cycle)

Intriguing new features for the "Scalable family (formerly Skylake)":

2.10. AVX-512F: Embedded Rounding
Intel AVX-512 introduces the embedded rounding feature, which allows most arithmetic instructions to define a rounding mode applied to just this particular instruction.

3.1. AVX-512ER: Exponential, Reciprocal
AVX-512ER, the “exponential and reciprocal” module, provides high-accuracy base-2 exponential, reciprocal and reciprocal square root instructions. It is supported by KNL only.

[I think Knights Landing is discontinued.]

2.9. AVX-512F: Ternary Logic

A new operation, ternary logic, is introduced by AVX-512F to execute any bitwise logical functions on three operands in one instruction.

https://stackoverflow.com/questions/74806689/can-float16-data-type-save-compute-cycles-while-computing-transcendental-functio

vdivsh has worse latency than vdivss on Alder Lake.

Recent Intel CPUs (since Broadwell) use a radix-1024 divider to get division done in fewer steps. [...] e.g. Skylake packed double-precision division (vdivpd ymm) has 16 times worse throughput than multiplication (vmulpd ymm), and it's worse in earlier CPUs with less powerful divide hardware. agner.org/optimize

https://stackoverflow.com/questions/4125033/floating-point-division-vs-floating-point-multiplication

Division has worse latency than multiplication or addition (or FMA) by a factor of 2 to 4 on modern x86 CPUs, and worse throughput by a factor of 6 to 40 (for a tight loop doing only division instead of only multiplication).

The divide / sqrt unit is not fully pipelined, for reasons explained in @NathanWhitehead's answer. The worst ratios are for 256b vectors, because (unlike other execution units) the divide unit is usually not full-width, so wide vectors have to be done in two halves. A not-fully-pipelined execution unit is so unusual that Intel CPUs have an arith.divider_active hardware performance counter to help you find code that bottlenecks on divider throughput instead of the usual front-end or execution port bottlenecks.

Found while looking this up; it seems to have info on that "arith.divider_active" hardware performance counter:
ocornut/imgui#4091

Catastrophic errors are possible from adding just 3 (2?) numbers by naive means.
Is that a real worry compared to the status quo (naive summation, assuming that is what's done)?

@nalimilan
Member

See #52397.

@stevengj
Copy link
Member

@nalimilan, yes, with #52397 we could in principle re-implement mean using mapreduce on pairs of (sum, count) and get the pairwise-summation accuracy. It wouldn't get the recursive rescaling (to avoid overflow in Float16) of #52365 (comment) though.
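As a rough sketch of that idea (not the code from #52397; mean_by_pairs is a hypothetical name): map each element to a (sum, count) pair and reduce with elementwise addition. In current Base, mapreduce over an AbstractArray already reduces pairwise, while over a general iterator it falls back to a sequential fold, so the iterator case only gains pairwise accuracy with a change like #52397. Like the other sketches in this thread, it throws on an empty input.

```julia
# Mean via a reduction over (partial sum, count) pairs.
function mean_by_pairs(itr)
    s, n = mapreduce(x -> (x, 1),                                  # each element: (value, count 1)
                     ((s1, n1), (s2, n2)) -> (s1 + s2, n1 + n2),   # merge two partial results
                     itr)
    return s / n
end

mean_by_pairs(Float32[1, 2, 3, 4])              # 2.5f0
mean_by_pairs(x for x in Float32[1, 2, 3, 4])   # same value via a generator
```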
