Last active
April 26, 2022 07:41
-
-
Save Ending2015a/aed404892de353083c148a74768c3445 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# --- built in --- | |
import abc | |
import math | |
# --- 3rd party --- | |
import numpy as np | |
class SegmentTree(metaclass=abc.ABCMeta):
    '''Sum segment tree over a fixed number of float64 elements.

    The tree is stored in the flat array `self._value`: node 1 is the root,
    node `i` has children `2i` and `2i+1`, and element `j` is stored at leaf
    `self._base + j`. Every internal node holds the sum of its two children.
    Leaves past `size` (padding up to the next power of two) stay 0.
    '''

    def __init__(self, size: int):
        '''An implementation of segment tree used to efficiently O(logN)
        compute the sum of a query range [start, end)

        All elements start at 0.

        Args:
            size (int): Number of elements. Must be a positive int.
        '''
        assert isinstance(size, int) and size > 0
        # Round the leaf count up to a power of two so that every leaf sits
        # at the same depth; `update` and `index` move whole batches of nodes
        # one level at a time and rely on this.
        base = 1 << (size - 1).bit_length()
        self._size = size
        self._base = base
        # Index 0 is unused, 1 is the root, [base, 2*base) are the leaves.
        self._value = np.zeros([base * 2], dtype=np.float64)

    def _leaf_indices(self, key) -> np.ndarray:
        '''Map element indices (int, slice or array-like) to leaf positions.

        Ints and slices are resolved like Python sequence indices, so an
        out-of-range int raises IndexError. Array-like indices are taken
        modulo `size`: negative ones count from the end and out-of-range
        ones wrap around.
        '''
        if isinstance(key, (int, slice)):
            key = np.asarray(range(self._size)[key], dtype=np.int64)
        else:
            key = np.asarray(key, dtype=np.int64)
        return key % self._size + self._base

    def __getitem__(self, key: np.ndarray):
        '''Return the stored element value(s) at `key`.

        `key` may be an int, a slice or an array-like of indices. The result
        has the shape of the resolved index array.
        '''
        return self._value[self._leaf_indices(key)]

    def __setitem__(self, key: np.ndarray, value: np.ndarray):
        '''Alias for `update(key, value)`.'''
        self.update(key, value)

    def update(self, key: np.ndarray, value: np.ndarray):
        '''Update elements' values and refresh every affected partial sum.

        Args:
            key: int, slice or array-like of element indices. Both key and
                value are flattened before use.
            value: scalar or array-like of new values. A single value is
                broadcast to every key; otherwise its size must match the
                key's size.
        '''
        key = self._leaf_indices(key).flatten()
        value = np.asarray(value, dtype=np.float64).flatten()
        # set values (this also validates that value fits key's shape)
        self._value[key] = value
        if key.size == 0:
            # Nothing to propagate (e.g. `tree[[]] = []` or an empty slice);
            # without this guard `key[0]` below raises IndexError.
            return
        # All leaves share one depth, so the whole batch moves up a level
        # together. The right-hand side is evaluated before assignment, so
        # siblings that share a parent write the same sum to it.
        while key[0] > 1:
            self._value[key >> 1] = self._value[key] + self._value[key ^ 1]
            key >>= 1

    def sum(self, start: int = None, end: int = None):
        '''Compute the sum of the elements in the half-open range [start, end).

        `start` and `end` follow Python slice semantics: None means the
        corresponding end of the array, negative values count from the end,
        and out-of-range bounds are clamped. An empty range sums to 0.0.
        '''
        if start is None and end is None:
            # shortcut: the root already holds the total
            return self._value[1]
        start, end, _ = slice(start, end).indices(self._size)
        start += self._base
        end += self._base
        res = 0.0
        while start < end:
            # An odd `start` is a right child whose parent also covers the
            # element just before the range; an odd `end` (exclusive) means
            # `end-1` is a left child whose parent also covers `end`. Add
            # such boundary nodes on their own, then step inward.
            if start & 1:
                res += self._value[start]
            if end & 1:
                res += self._value[end - 1]
            start = (start + 1) >> 1
            end = end >> 1
        return res

    def index(self, value: np.ndarray):
        '''Find, for each query, the element at which the running sum reaches it.

        For a query `v`, returns the smallest index `i` such that
        `sum(0, i+1) >= v`. For `v > 0` this is the same as the largest `i`
        with `sum(0, i) < v`. With non-negative element values and `v` drawn
        uniformly from [0, total), index `i` is therefore returned with
        probability proportional to element `i`'s value.

        Args:
            value: scalar or array-like of queries, each in [0, total).

        Returns:
            An int for a scalar query. Otherwise an int64 array with the same
            shape as `value`.
        '''
        assert np.min(value) >= 0.0
        assert np.max(value) < self._value[1]
        # if input is a scalar, return should be a scalar too.
        one_value = np.isscalar(value)
        # flatten() copies, so decrementing `queries` in place below never
        # touches the caller's array.
        queries = np.asarray(value, dtype=np.float64)
        orig_shape = queries.shape
        queries = queries.flatten()
        nodes = np.ones_like(queries, dtype=np.int64)
        # Descend from the root one level per step (all leaves share a depth).
        while nodes[0] < self._base:
            nodes <<= 1  # move to the left child
            left_sum = self._value[nodes]
            go_right = left_sum < queries
            # Skipping the left subtree consumes its mass from the query.
            queries -= left_sum * go_right
            nodes += go_right
        # Rounding in the stored partial sums can leave a remainder slightly
        # larger than the right subtree's sum, steering the descent into the
        # zero padding leaves past the last element; clamp those back.
        inds = np.minimum(nodes - self._base, self._size - 1)
        inds = inds.reshape(orig_shape)
        return inds.item() if one_value else inds
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment