Last active
October 23, 2023 10:47
-
-
Save eladn/e9cb341a912af2f761b06c97d5ff9e80 to your computer and use it in GitHub Desktop.
Python mean & variance moments accumulator; numerically stable; based on Welford's algorithm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Module metadata.
__author__ = "Elad Nachmias"
__email__ = "eladnah@gmail.com"
__date__ = "2023-10-19"

import copy
import dataclasses
from typing import Tuple, Optional

import numpy as np
@dataclasses.dataclass(frozen=True)
class AccumulatedMomentStatistics:
    """Immutable snapshot of accumulated moment statistics."""

    # Number of values the statistics were computed from.
    nr_values: int
    # Mean of the values.
    mean: float
    # Variance of the values.
    variance: float
    # Standard deviation of the values.
    std: float
class OnlineMomentsAccumulator:
    """
    Used for maintaining accumulators for mean and variance of an online stream of real-valued scalars.
    Main usage is calling `push(value)` to append new values to the accumulation (update in O(1) complexity).
    The current statistics values (mean / variance / std) are accessible at any stage in O(1) complexity.
    For the variance, the unbiased sample-estimator is maintained (defined as 0.0 for a single value).
    Internally, the accumulator only stores 2 float scalars and an integer (the entire series is not stored).

    If `approximate_moving_window_size` is given, the stored point count saturates at that size, so every
    newly pushed value keeps a weight of at least `1 / (window + 1)`. The resulting statistics are then an
    approximation only; no individual value is ever evicted.
    """

    def __init__(self, approximate_moving_window_size: Optional[int] = None):
        """
        :param approximate_moving_window_size: optional cap (>= 1) on the stored point count; `None` means
            that all pushed values are weighted equally.
        :raises AssertionError: if the window size is given and is smaller than 1.
        """
        assert approximate_moving_window_size is None or approximate_moving_window_size >= 1
        self._mean: float = np.nan
        self._variance: float = np.nan
        self._nr_points: int = 0
        self._approximate_moving_window_size = approximate_moving_window_size

    def clone(self) -> 'OnlineMomentsAccumulator':
        """Return an independent deep copy of this accumulator."""
        return copy.deepcopy(self)

    @property
    def mean(self) -> float:
        """Current mean (`NaN` while the accumulator is empty)."""
        return self._mean

    @property
    def variance(self) -> float:
        """
        Unbiased sample-estimator for variance (`NaN` while empty, 0.0 for a single value).
        """
        return self._variance

    @property
    def std(self) -> float:
        """Square root of `variance`."""
        return np.sqrt(self.variance)

    @property
    def nr_points(self) -> int:
        """Stored point count (never exceeds the window size when one is configured)."""
        assert self._approximate_moving_window_size is None or self._nr_points <= self._approximate_moving_window_size
        return self._nr_points

    def __len__(self) -> int:
        return self.nr_points

    def merge(self, other: 'OnlineMomentsAccumulator', is_other_newer: bool = True):
        """
        Add the statistics from given `other` into `self`. Modify `self` in-place; `other` is not modified.

        :param other: the accumulator whose statistics are merged in.
        :param is_other_newer: only relevant when a window size is configured. If True, `other` takes
            precedence inside the window and `self` only receives the remaining room. If False, the window
            is shared: `self` gets (at least) the larger half plus whatever `other` leaves unused.
        """
        window = self._approximate_moving_window_size
        # `n1` / `n2` are the effective weights (in points) given to `self` / `other` in the merge.
        n1, n2 = self.nr_points, other.nr_points
        if window is not None:
            if is_other_newer:
                # We assume the data in `self` is older than the data in `other` which is merged into here.
                # `self` can never weigh more than the points it actually holds.
                n1 = min(n1, max(0, window - n2))
            elif n1 + n2 > window:
                half_win = int(np.ceil(window / 2))
                # Keep `n1 + n2 <= window` also for odd windows (where `2 * half_win == window + 1`).
                n1 = min(n1, max(half_win, window - n2))
                n2 = min(n2, window - n1)

        if n2 == 0:
            return
        if n2 == 1:
            # A single effective point: treat the mean of `other` as one pushed value.
            self.push(other.mean)
            return

        mu1, mu2 = self.mean, other.mean
        var1, var2 = self.variance, other.variance
        if n1 == 0:
            joint_mean = mu2
            joint_var = var2
        else:
            # Here n1 >= 1 and n2 >= 2, hence n12 - 1 >= 2 and the divisions below are safe.
            n12 = n1 + n2
            delta = mu2 - mu1
            joint_mean = mu1 + delta * (n2 / n12)
            # Sample-based pooled variance for 2 disjoint sets of samples.
            # link: https://en.wikipedia.org/wiki/Pooled_variance#Sample-based_statistics
            # The between-groups term is written as (n1 * n2 / n12) * delta ** 2, which is algebraically equal
            # to (n1 * mu1 ** 2 + n2 * mu2 ** 2 - n12 * joint_mean ** 2) but avoids subtracting huge nearly
            # equal numbers when the means are large relative to the spread.
            joint_var = ((n1 - 1) * var1 + (n2 - 1) * var2 + (n1 * n2 / n12) * delta ** 2) / (n12 - 1)
        # Only negative round-off is clamped; small positive variances are legitimate and are kept.
        joint_var = max(joint_var, 0.)

        if window is None:
            self._nr_points += n2
        else:
            self._nr_points = min(self._nr_points + n2, window)
        self._mean = joint_mean
        self._variance = joint_var

    def __add__(self, other: 'OnlineMomentsAccumulator') -> 'OnlineMomentsAccumulator':
        """Return a new accumulator holding the merge of `self` and `other` (neither is modified)."""
        new_acc = self.clone()
        new_acc.merge(other, is_other_newer=False)
        return new_acc

    def __iadd__(self, other: 'OnlineMomentsAccumulator') -> 'OnlineMomentsAccumulator':
        """Merge `other` into `self` in-place."""
        self.merge(other, is_other_newer=False)
        return self

    def push(self, value: float) -> None:
        """
        Update the accumulated statistics by adding a new value to it.

        :raises AssertionError: if `value` is not finite.
        """
        assert np.isfinite(value)
        if self._nr_points == 0:
            self._mean = value
            self._variance = 0.0
            self._nr_points = 1
            return
        self._mean, self._variance = self._perform_update_step(
            value=value, last_mean=self._mean,
            last_variance=self._variance, last_nr_points=self._nr_points)
        if self._approximate_moving_window_size is None:
            self._nr_points += 1
        else:
            # Saturate the count at the window size.
            self._nr_points = min(self._nr_points + 1, self._approximate_moving_window_size)

    def pop(self, value: float) -> None:
        """
        Update the accumulated statistics by *removing* a value that has been already added to it before.

        :raises AssertionError: if the accumulator is empty.
        """
        assert self.nr_points > 0
        if self.nr_points == 1:
            # Back to the empty state.
            self._mean = np.nan
            self._variance = np.nan
            self._nr_points = 0
            return
        self._mean, self._variance = self._perform_remove_step(
            value=value, last_mean=self._mean,
            last_variance=self._variance, last_nr_points=self.nr_points)
        self._nr_points -= 1

    def get_stats(self) -> 'AccumulatedMomentStatistics':
        """Return a snapshot of the current count, mean, variance and std."""
        return AccumulatedMomentStatistics(
            nr_values=self.nr_points, mean=self.mean, variance=self.variance, std=self.std)

    @classmethod
    def _perform_update_step(
            cls, value: float, last_mean: float, last_variance: float, last_nr_points: int) -> Tuple[float, float]:
        """
        Compute the (mean, variance) obtained by adding `value` to statistics of `last_nr_points` (>= 1) values.
        """
        dx = value - last_mean
        next_mean = last_mean + dx / (last_nr_points + 1)
        # This logic is based on Welford's online algorithm. Link:
        # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
        # The update rule for a population-estimator (divide by `n`) would be:
        #   (last_nr_points / (last_nr_points + 1)) * (last_variance + dx ** 2 / (last_nr_points + 1))
        # This update rule is for maintaining an unbiased sample-estimator (divide by `n-1`):
        next_variance = \
            last_variance + ((last_nr_points / (last_nr_points + 1)) * dx ** 2 - last_variance) / last_nr_points
        # Only negative round-off is clamped; small positive variances are legitimate and are kept.
        next_variance = max(next_variance, 0.)
        return next_mean, next_variance

    @classmethod
    def _perform_remove_step(
            cls, value: float, last_mean: float, last_variance: float, last_nr_points: int) -> Tuple[float, float]:
        """
        Compute the (mean, variance) obtained by removing `value` from statistics of `last_nr_points` values.
        Returns (NaN, NaN) when the last remaining value is removed.
        """
        if last_nr_points == 1:
            return np.nan, np.nan
        dx = value - last_mean
        # mu_(n-1) = mu_(n) - (x - mu_(n)) / (n - 1)
        next_mean = last_mean - dx / (last_nr_points - 1)
        if last_nr_points <= 2:
            # A single value remains; its variance is defined as 0.
            return next_mean, 0.
        # Inverse of the update rule of the unbiased sample-estimator (divide by `n-1`), expressed via the
        # deviation of the removed value from the mean of the remaining values.
        dx2 = value - next_mean
        next_variance = \
            last_variance - \
            (((last_nr_points - 1) / last_nr_points) * dx2 ** 2 - last_variance) / (last_nr_points - 2)
        # Only negative round-off is clamped; small positive variances are legitimate and are kept.
        next_variance = max(next_variance, 0.)
        return next_mean, next_variance
def test_online_mean_variance_accumulators_push_only():
    """Compare the running mean / variance after every push against NumPy statistics of the same prefix."""
    rng = np.random.RandomState(seed=0)
    values_array = rng.random(size=100)
    mva = OnlineMomentsAccumulator()
    mva_means = []
    mva_variances = []
    for item in values_array:
        mva.push(item)
        mva_means.append(mva.mean)
        mva_variances.append(mva.variance)
    mva_means, mva_variances = np.array(mva_means), np.array(mva_variances)

    nr_values = len(values_array)
    np_means = np.array([np.mean(values_array[:end]) for end in range(1, nr_values + 1)])
    assert np.allclose(mva_means, np_means)

    # We exclude the case of a single-value array from the comparison of the variance. Numpy returns a variance of
    # `NaN` for an array with a single item (because of the division by (n-1) in the raw variance formula of degree
    # of freedom = 1, which causes a division by zero if applied for n=1), while we intentionally choose the
    # definition of var=0.0 for this case (n=1). Hence the reference variances start from prefixes of 2 values.
    np_variances = np.array([np.var(values_array[:end], ddof=1) for end in range(2, nr_values + 1)])
    assert np.allclose(mva_variances[1:], np_variances)
    assert np.isclose(mva_variances[0], 0.)
def test_online_mean_variance_accumulators_randomly_pop_while_pushing():
    """
    Randomly interleave pushes and pops (both in array order) and verify after every operation that the
    accumulator matches the statistics of the values that are pushed but not yet popped.
    """
    rng = np.random.RandomState(seed=0)
    mva = OnlineMomentsAccumulator()
    total_nr_items = 100
    arr = rng.random(size=total_nr_items)
    pushed_mask = np.full(shape=arr.shape, fill_value=False, dtype=bool)
    popped_mask = np.full(shape=arr.shape, fill_value=False, dtype=bool)
    while True:
        are_there_items_left_to_push = not np.all(pushed_mask)
        pushed_but_not_popped_mask = pushed_mask & ~popped_mask
        _assert_same_stats_as_arr(mva=mva, arr=arr[pushed_but_not_popped_mask])
        are_there_items_waiting_to_be_popped = np.any(pushed_but_not_popped_mask)
        if not are_there_items_left_to_push and not are_there_items_waiting_to_be_popped:
            break

        # Choose the next operation; it is forced when only one of the two is possible.
        if are_there_items_left_to_push and are_there_items_waiting_to_be_popped:
            operation = rng.choice(['push', 'pop'], p=[0.75, 0.25])
        elif are_there_items_left_to_push:
            operation = 'push'
        else:
            assert are_there_items_waiting_to_be_popped
            operation = 'pop'

        if operation == 'push':
            # `argmax` over a boolean mask gives the first index that is True.
            next_item_idx_to_push = np.argmax(~pushed_mask)
            mva.push(arr[next_item_idx_to_push])
            pushed_mask[next_item_idx_to_push] = True
        elif operation == 'pop':
            next_item_idx_to_pop = np.argmax(pushed_but_not_popped_mask)
            mva.pop(arr[next_item_idx_to_pop])
            popped_mask[next_item_idx_to_pop] = True
    assert np.sum(popped_mask) == total_nr_items
def test_online_mean_variance_accumulators_merge():
    """Verify that `mva1 + mva2` matches the statistics of the concatenation of both underlying arrays."""
    rng = np.random.RandomState(seed=0)
    values_array1 = rng.random(size=100)
    values_array2 = rng.random(size=150)  # different size to ensure it still works for asymmetric |sets|
    mva1 = OnlineMomentsAccumulator()
    mva2 = OnlineMomentsAccumulator()
    for item in values_array1:
        mva1.push(item)
    for item in values_array2:
        mva2.push(item)
    _assert_same_stats_as_arr(mva1, values_array1)
    _assert_same_stats_as_arr(mva2, values_array2)

    merged_mva = mva1 + mva2
    merged_array = np.concatenate([values_array1, values_array2])
    _assert_same_stats_as_arr(merged_mva, merged_array)
def _assert_same_stats_as_arr(mva: 'OnlineMomentsAccumulator', arr: np.ndarray):
    """
    Assert that `mva` holds the same count, mean and unbiased variance as the values in `arr`.

    For an empty `arr` only the count is compared; for a single value the variance is expected to be 0
    (the accumulator's convention), since `np.var(..., ddof=1)` is undefined there.
    """
    assert len(mva) == len(arr)
    if len(mva) == 0:
        return
    assert np.isclose(mva.mean, np.mean(arr))
    if len(mva) == 1:
        assert np.isclose(mva.variance, 0.)
    else:
        assert np.isclose(mva.variance, np.var(arr, ddof=1))
# Run the self-tests when the module is executed as a script.
if __name__ == '__main__':
    test_online_mean_variance_accumulators_push_only()
    test_online_mean_variance_accumulators_randomly_pop_while_pushing()
    test_online_mean_variance_accumulators_merge()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment