Performance optimizations of various statistics #201

Open
JCGoran opened this issue Jul 18, 2024 · 2 comments · May be fixed by #204
Labels
enhancement New feature or request
Milestone

Comments

@JCGoran
Contributor

JCGoran commented Jul 18, 2024

Is your feature request related to a problem? Please describe.

Not really a problem, more like a potential optimization (I haven't worked out the details to see if it actually works).

So, if I've understood how the algorithm works under the hood, we basically move the points in the dataset one at a time and then, at each iteration, compute the mean, the standard deviation, and the correlation coefficient of the new dataset.
One thing that stands out performance-wise is that we currently use all of the points to compute the statistics at each step, which seems a bit wasteful.

Describe the solution you'd like

Instead of computing the statistics of the whole dataset, which requires iterating over all $n$ points (and even more operations for the stdev/corrcoef), we can use the fact that we only move one point at a time and rewrite the new statistics as the old ones plus a perturbation. For instance, the new value of the mean is:

$$ \langle X' \rangle = \langle X \rangle + \frac{\delta}{n} $$

where $\delta = x'_i - x_i$ and $n$ is the number of points in the dataset. An analogous formula can be derived for the variance, which is just the square of the stdev (some tweaking of the denominators may be needed to account for the Bessel correction):

$$ \text{Var}(X') = \text{Var}(X) + 2 \frac{\delta}{n}(x_i - \langle X \rangle) + \frac{\delta^2}{n} - \frac{\delta^2}{n^2} $$

and probably for the correlation coefficient (or better, its square) as well. This would allow us to compute all of the statistics in basically $O(1)$ time, instead of $O(n)$ or larger.
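As a quick sanity check of the algebra (just a sketch, not anything from the package), one can perturb a single point and compare the $O(1)$ updates against recomputing numpy's population mean/variance from scratch:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=1_000)
n = len(x)

# population statistics (ddof=0), matching the 1/n factors in the formulas above
mean, var = x.mean(), x.var()

i = 42         # index of the point being moved (arbitrary)
delta = 0.5    # delta = x'_i - x_i

new_mean = mean + delta / n
new_var = var + 2 * delta / n * (x[i] - mean) + delta**2 / n - delta**2 / n**2

x[i] += delta  # actually move the point
assert np.isclose(new_mean, x.mean())
assert np.isclose(new_var, x.var())
```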

There's at least one problem I haven't worked out yet: is this numerically stable? Numerical accuracy is paramount for the code to work properly, so if the above incurs a large loss of accuracy it's not very useful; if it's stable, it could be worthwhile to explore implementing it.
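Not an answer to the stability question, but one way to probe it empirically might be to apply many single-point updates with the formulas above and measure how far the running values drift from a fresh recomputation (a rough sketch, everything in float64):

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=10_000)
n = len(x)
mean, var = x.mean(), x.var()

for _ in range(100_000):
    i = rng.integers(n)
    delta = rng.normal(scale=0.01)
    # update the variance first, since it needs the *old* mean and the old x[i]
    var += 2 * delta / n * (x[i] - mean) + delta**2 / n - delta**2 / n**2
    mean += delta / n
    x[i] += delta

print("mean drift:", abs(mean - x.mean()))
print("var drift: ", abs(var - x.var()))
```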

Some references that could be of use (regarding both the computation and numerical stability):

Describe alternatives you've considered

None.

Additional context

None.

@JCGoran linked a pull request (#204) Jul 21, 2024 that will close this issue
@stefmolin
Owner

Make sure you account for the plan to add the median (see #181).

@stefmolin stefmolin added the enhancement New feature or request label Jul 21, 2024
@stefmolin stefmolin added this to the 0.4.0 milestone Jul 21, 2024
@JCGoran
Contributor Author

JCGoran commented Jul 23, 2024

After some consideration, the median seems to be implementable as follows:

  1. sort the input data
  2. split the now sorted data into two AVL trees, one of size n // 2 (lower part), and the other of size n - n // 2 (higher part)
  3. after the above step, the median is then either min(higher part) (case n odd) or (max(lower part) + min(higher part)) / 2 (case n even)
  4. doing the replacement $x_i \mapsto x'_i$ is equivalent to removing $x_i$ from one of the two trees, followed by inserting $x'_i$ into one of the two trees (basically a bunch of if statements, depending on where we're removing/inserting)
  5. to maintain the balance of the trees (i.e. both trees have the same number of elements when n is even, or the higher part has one extra element when n is odd), we occasionally need to either remove the largest element from the lower part and insert it into the higher part, or remove the smallest element from the higher part and insert it into the lower part
  6. after the rebalancing is done, go to step 3 to get the median

An AVL tree does all of the operations above in $O(\log n)$, so we can find the median of the "perturbed" dataset in $O(\log n)$ (numpy.median runs in $O(n)$ thanks to using something like quickselect, as opposed to the naive "sort and find", which runs in $O(n \log n)$).
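For illustration, here's a minimal sketch of the scheme above, using sortedcontainers.SortedList as a stand-in for the AVL trees (it also offers $O(\log n)$ insert/remove with ordered access); the class and method names are made up for the example, not taken from the package:

```python
from sortedcontainers import SortedList  # stand-in for an AVL tree


class RunningMedian:
    def __init__(self, data):
        ordered = sorted(data)                   # step 1: sort once, up front
        half = len(ordered) // 2
        self.lower = SortedList(ordered[:half])  # step 2: lower n // 2 values
        self.upper = SortedList(ordered[half:])  #         upper n - n // 2 values

    def replace(self, old, new):
        # step 4: remove the old value from whichever half holds it ...
        if self.lower and old <= self.lower[-1]:
            self.lower.remove(old)
        else:
            self.upper.remove(old)
        # ... and insert the new value on the side it belongs to
        if self.lower and new <= self.lower[-1]:
            self.lower.add(new)
        else:
            self.upper.add(new)
        # step 5: rebalance so the higher part keeps either the same number of
        # elements as the lower part (n even) or one extra (n odd)
        if len(self.lower) > len(self.upper):
            self.upper.add(self.lower.pop(-1))
        elif len(self.upper) > len(self.lower) + 1:
            self.lower.add(self.upper.pop(0))

    @property
    def median(self):
        # step 3: n odd -> min of the higher part; n even -> mean of the two middle values
        if (len(self.lower) + len(self.upper)) % 2:
            return self.upper[0]
        return (self.lower[-1] + self.upper[0]) / 2
```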

I used the timeit module to see whether this actually works, and the results are encouraging (tried out on an array w/ 5M elements, 100 repeats):

  • avg(np): 7.2 ms
  • avg(avl): 0.087 ms

Note that I'm not counting the initial sort (which is $O(n \log n)$) in the timings, since we only need to do it once, at the start of the simulation.
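Roughly, the comparison can be reproduced along these lines (using the RunningMedian sketch from above; the setup here is illustrative, not the original benchmark, and the exact numbers will vary by machine):

```python
import timeit

import numpy as np

data = np.random.default_rng(0).normal(size=5_000_000)
tracker = RunningMedian(data)


def tree_median():
    # replace one point, read off the new median, then undo the replacement
    # so every timing iteration starts from the same state
    tracker.replace(data[0], data[0] + 0.1)
    m = tracker.median
    tracker.replace(data[0] + 0.1, data[0])
    return m


repeats = 100
t_np = timeit.timeit(lambda: np.median(data), number=repeats) / repeats
t_tree = timeit.timeit(tree_median, number=repeats) / repeats
print(f"avg(np):  {t_np * 1e3:.3f} ms")
print(f"avg(avl): {t_tree * 1e3:.3f} ms")
```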

Maybe I'm overcomplicating things, but it seems to me that we need a tree-like structure for this; at first I considered using the heapq builtin module to implement the two-heap algorithm, but it doesn't support removal of arbitrary elements, only the top-most one, so I opted for an AVL tree instead.
