"""
Python mean & variance moments accumulator; numerically stable; based on Welford's algorithm.

Originally published as GitHub Gist eladn/e9cb341a912af2f761b06c97d5ff9e80 by @eladn
(last active October 23, 2023 10:47).
"""
# Gist metadata: author, contact e-mail and creation date.
__author__, __email__, __date__ = "Elad Nachmias", "eladnah@gmail.com", "2023-10-19"
import copy
import dataclasses
from typing import Tuple, Optional
import numpy as np
@dataclasses.dataclass(frozen=True)
class AccumulatedMomentStatistics:
    """Immutable snapshot of the statistics held by a moments accumulator."""

    nr_values: int  # number of accumulated values
    mean: float  # mean of the accumulated values
    variance: float  # sample variance, as reported by `OnlineMomentsAccumulator.variance`
    std: float  # square root of `variance`
class OnlineMomentsAccumulator:
    """
    Maintain the running mean and variance of an online stream of real-valued scalars.

    Main usage is calling `push(value)` to append new values to the accumulation (O(1) per update).
    `pop(value)` removes a previously pushed value. `merge()` and `+` combine two accumulators.
    The current statistics (mean / variance / std) are accessible at any stage in O(1).
    For the variance, the unbiased sample-estimator (divide by `n - 1`) is maintained. A single value
    is defined to have variance 0.0. An empty accumulator reports NaN mean and variance.
    Internally, the accumulator only stores 2 float scalars and an integer (the series is not stored).

    If `approximate_moving_window_size` is given, the point count saturates at that size. Each new
    value then keeps a weight of about `1 / (window + 1)` instead of `1 / (n + 1)`. Older values
    decay rather than being dropped exactly, which approximates a moving window.
    """

    def __init__(self, approximate_moving_window_size: Optional[int] = None):
        """
        :param approximate_moving_window_size: optional size (>= 1) at which the point count
            saturates; see the class docstring.
        """
        self._mean: float = np.nan
        self._variance: float = np.nan
        self._nr_points: int = 0
        assert approximate_moving_window_size is None or approximate_moving_window_size >= 1
        self._approximate_moving_window_size = approximate_moving_window_size

    def clone(self) -> 'OnlineMomentsAccumulator':
        """Return an independent deep copy of this accumulator."""
        return copy.deepcopy(self)

    @property
    def mean(self) -> float:
        """Mean of the accumulated values (NaN when empty)."""
        return self._mean

    @property
    def variance(self) -> float:
        """
        Unbiased sample-estimator for variance (0.0 for a single value, NaN when empty).
        """
        return self._variance

    @property
    def std(self) -> float:
        """Square root of `variance`."""
        return np.sqrt(self.variance)

    @property
    def nr_points(self) -> int:
        """Number of accumulated values (never above the window size, when one is set)."""
        assert self._approximate_moving_window_size is None or self._nr_points <= self._approximate_moving_window_size
        return self._nr_points

    def __len__(self) -> int:
        return self.nr_points

    def merge(self, other: 'OnlineMomentsAccumulator', is_other_newer: bool = True) -> None:
        """
        Add the statistics from given `other` into `self`. Modify `self` in-place.

        :param other: accumulator whose statistics are folded into `self` (it is not modified).
        :param is_other_newer: only relevant when `self` has an approximate moving window.
            If True, the data in `other` is treated as newer than the data in `self`. `other` is
            then kept in full, and `self` only fills whatever room is left in the window.
            Otherwise, when the window would overflow, it is split roughly in half between the two.
        """
        n1, n2 = self._effective_merge_counts(other, is_other_newer)
        if n2 == 0:
            return
        mu2, var2 = other.mean, other.variance
        if n2 == 1:
            # A single effective point from `other` is just its mean: reuse the Welford push step.
            self.push(mu2)
            return
        if n1 == 0:
            joint_mean, joint_var = mu2, var2
        else:
            mu1, var1 = self.mean, self.variance
            n12 = n1 + n2
            delta = mu2 - mu1
            joint_mean = mu1 + delta * (n2 / n12)
            # Variance estimation is based on sample-based pooled variance for 2 disjoint sets of samples.
            # link: https://en.wikipedia.org/wiki/Pooled_variance#Sample-based_statistics
            # The between-groups term is written as `n1 * n2 / n12 * delta ** 2` (Chan et al.).
            # It equals `n1 * mu1 ** 2 + n2 * mu2 ** 2 - n12 * joint_mean ** 2`, but it does not
            # cancel catastrophically when the means are large compared to the spread of the data.
            joint_var = self._clamp_variance(
                ((n1 - 1) * var1 + (n2 - 1) * var2 + (n1 * n2 / n12) * delta ** 2) / (n12 - 1))
        if self._approximate_moving_window_size is None:
            self._nr_points += n2
        else:
            self._nr_points = min(self._nr_points + n2, self._approximate_moving_window_size)
        self._mean = joint_mean
        self._variance = joint_var

    def _effective_merge_counts(
            self, other: 'OnlineMomentsAccumulator', is_other_newer: bool) -> Tuple[int, int]:
        """Return the sample counts `(n1, n2)` that `self` and `other` contribute to a merge."""
        n1, n2 = self.nr_points, other.nr_points
        window = self._approximate_moving_window_size
        if window is None:
            return n1, n2
        if is_other_newer:
            # We assume the data in `self` is older than the data in `other` which is merged into here.
            # `other` is kept in full, and `self` fills the room left in the window. `self` can never
            # contribute more points than it actually holds. This also keeps an empty `self`
            # (whose mean is NaN) out of the weighted average.
            return min(n1, max(0, window - n2)), n2
        if n1 + n2 > window:
            # The window overflows: split it roughly in half, letting either side use the room
            # the other side leaves unused.
            half_win = int(np.ceil(window / 2))
            n1 = min(n1, half_win + max(0, half_win - n2))
            n2 = min(n2, window - n1)
        return n1, n2

    def __add__(self, other: 'OnlineMomentsAccumulator') -> 'OnlineMomentsAccumulator':
        """Return a new accumulator holding the merged statistics of `self` and `other`."""
        new_acc = self.clone()
        new_acc.merge(other, is_other_newer=False)
        return new_acc

    def __iadd__(self, other: 'OnlineMomentsAccumulator') -> 'OnlineMomentsAccumulator':
        """Merge `other` into `self` in-place."""
        self.merge(other, is_other_newer=False)
        return self

    def push(self, value: float) -> None:
        """
        Update the accumulated statistics by adding a new value to it.

        :param value: a finite real number (asserted).
        """
        assert np.isfinite(value)
        if self._nr_points == 0:
            self._mean = value
            self._variance = 0.0
            self._nr_points = 1
            return
        self._mean, self._variance = self._perform_update_step(
            value=value, last_mean=self._mean,
            last_variance=self._variance, last_nr_points=self._nr_points)
        if self._approximate_moving_window_size is None:
            self._nr_points += 1
        else:
            self._nr_points = min(self._nr_points + 1, self._approximate_moving_window_size)

    def pop(self, value: float) -> None:
        """
        Update the accumulated statistics by *removing* a value that has been already added to it before.

        The caller is responsible for `value` having actually been pushed; this is not checked.
        """
        assert self.nr_points > 0
        if self.nr_points == 1:
            self._mean = np.nan
            self._variance = np.nan
            self._nr_points = 0
        else:
            self._mean, self._variance = self._perform_remove_step(
                value=value, last_mean=self._mean,
                last_variance=self._variance, last_nr_points=self.nr_points)
            self._nr_points -= 1

    def get_stats(self) -> 'AccumulatedMomentStatistics':
        """Return an immutable snapshot of the current statistics."""
        return AccumulatedMomentStatistics(
            nr_values=self.nr_points, mean=self.mean, variance=self.variance, std=self.std)

    @staticmethod
    def _clamp_variance(variance: float) -> float:
        """
        Clip negative floating-point residue to 0.0.

        Tiny positive variances are legitimate (for example, data whose std is 1e-5) and are kept
        as-is. A fixed `np.isclose(v, 0.)` test would zero them, because its default atol is 1e-8.
        NaN also passes through unchanged.
        """
        return 0.0 if variance < 0.0 else variance

    @classmethod
    def _perform_update_step(
            cls, value: float, last_mean: float, last_variance: float, last_nr_points: int) -> Tuple[float, float]:
        """
        Update the accumulated statistics by adding a new value to it.

        :return: `(next_mean, next_variance)` after adding `value` to `last_nr_points` points.
        """
        dx = value - last_mean
        next_mean = last_mean + dx / (last_nr_points + 1)
        # This logic is based on Welford's online algorithm. Link:
        # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
        # The population-estimator variant (divide by `n`) would be
        # `(n / (n + 1)) * (last_variance + dx ** 2 / (n + 1))`.
        # Here the unbiased sample-estimator (divide by `n-1`) is maintained instead.
        next_variance = \
            last_variance + ((last_nr_points / (last_nr_points + 1)) * dx ** 2 - last_variance) / last_nr_points
        return next_mean, cls._clamp_variance(next_variance)

    @classmethod
    def _perform_remove_step(
            cls, value: float, last_mean: float, last_variance: float, last_nr_points: int) -> Tuple[float, float]:
        """
        Update the accumulated statistics by *removing* a value that has been already added to it before.

        This is the inverse of `_perform_update_step`.
        :return: `(next_mean, next_variance)` after removing `value` from `last_nr_points` points.
        """
        if last_nr_points == 1:
            return np.nan, np.nan
        dx = value - last_mean
        # mu_(n-1) = mu_(n) - (x - mu_(n))/(n-1)
        next_mean = last_mean - dx / (last_nr_points - 1)
        if last_nr_points > 2:
            # This update rule is for maintaining an unbiased sample-estimator (divide by `n-1`):
            dx2 = value - next_mean
            next_variance = \
                last_variance - \
                (((last_nr_points - 1) / last_nr_points) * dx2 ** 2 - last_variance) / (last_nr_points - 2)
            next_variance = cls._clamp_variance(next_variance)
        else:
            # A single point remains: its variance is defined as 0.0.
            next_variance = 0.
        return next_mean, next_variance
def test_online_mean_variance_accumulators_push_only():
    """After every push, the running mean / ddof=1 variance must match numpy on the prefix seen so far."""
    rng = np.random.RandomState(seed=0)
    samples = rng.random(size=100)
    accumulator = OnlineMomentsAccumulator()
    running_means = np.empty(len(samples))
    running_variances = np.empty(len(samples))
    for idx, sample in enumerate(samples):
        accumulator.push(sample)
        running_means[idx] = accumulator.mean
        running_variances[idx] = accumulator.variance
    prefixes = [samples[:end] for end in range(1, len(samples) + 1)]
    expected_means = np.array([np.mean(prefix) for prefix in prefixes])
    expected_variances = np.array([np.var(np.array(prefix), ddof=1) for prefix in prefixes])
    assert np.allclose(running_means, expected_means)
    # numpy's ddof=1 variance of a single item is NaN (it divides by n - 1 = 0), whereas the
    # accumulator deliberately defines it as 0.0. The first entry is therefore checked separately.
    assert np.allclose(running_variances[1:], expected_variances[1:])
    assert np.isclose(running_variances[0], 0.)
def test_online_mean_variance_accumulators_randomly_pop_while_pushing():
    """
    Interleave pushes with pops of the oldest live item, chosen at random.

    After every step, the statistics are compared with numpy on the items that are still live.
    """
    rng = np.random.RandomState(seed=0)
    acc = OnlineMomentsAccumulator()
    total_nr_items = 100
    values = rng.random(size=total_nr_items)
    pushed = np.zeros(values.shape, dtype=bool)
    popped = np.zeros(values.shape, dtype=bool)
    while True:
        live = pushed & ~popped
        _assert_same_stats_as_arr(mva=acc, arr=values[live])
        can_push = not np.all(pushed)
        can_pop = bool(np.any(live))
        if not (can_push or can_pop):
            break
        if can_push and can_pop:
            operation = rng.choice(['push', 'pop'], p=[0.75, 0.25])
        else:
            operation = 'push' if can_push else 'pop'
        if operation == 'push':
            idx = int(np.argmax(~pushed))
            acc.push(values[idx])
            pushed[idx] = True
        elif operation == 'pop':
            idx = int(np.argmax(live))
            acc.pop(values[idx])
            popped[idx] = True
    assert np.sum(popped) == total_nr_items
def test_online_mean_variance_accumulators_merge():
    """Merging two accumulators of different sizes must match accumulating their concatenation."""
    rng = np.random.RandomState(seed=0)
    first_values = rng.random(size=100)
    second_values = rng.random(size=150)  # different size to ensure it still works for asymmetric |sets|
    accumulators = []
    for values in (first_values, second_values):
        acc = OnlineMomentsAccumulator()
        for value in values:
            acc.push(value)
        _assert_same_stats_as_arr(acc, values)
        accumulators.append(acc)
    merged = accumulators[0] + accumulators[1]
    _assert_same_stats_as_arr(merged, np.concatenate([first_values, second_values]))
def _assert_same_stats_as_arr(mva: OnlineMomentsAccumulator, arr: np.ndarray):
    """
    Assert that `mva` reports the same count, mean and ddof=1 variance as numpy computes on `arr`.

    For a single value the expected variance is 0.0 (numpy would give NaN). For an empty `arr`,
    only the count is checked.
    """
    count = len(mva)
    assert count == len(arr)
    if not count:
        return
    assert np.isclose(mva.mean, np.mean(arr))
    expected_variance = 0. if count == 1 else np.var(arr, ddof=1)
    assert np.isclose(mva.variance, expected_variance)
if __name__ == '__main__':
    # Run the self-tests in the same order as they are defined above.
    for _test_fn in (
            test_online_mean_variance_accumulators_push_only,
            test_online_mean_variance_accumulators_randomly_pop_while_pushing,
            test_online_mean_variance_accumulators_merge,
    ):
        _test_fn()
# End of gist (scraped GitHub page footer removed).